1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod kept_rate;
9mod load_project;
10mod metrics;
11mod openai_client;
12mod parse_output;
13mod paths;
14mod predict;
15mod progress;
16mod prompt_assets;
17mod pull_examples;
18mod qa;
19mod reorder_patch;
20mod repair;
21mod retrieve_context;
22mod reversal_tracking;
23mod score;
24mod split_commit;
25mod split_dataset;
26
27mod synthesize;
28mod truncate_expected_patch;
29mod word_diff;
use anyhow::Context as _;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::{HashMap, HashSet};
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gaoya::minhash::{
    MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
};
use gpui::{AppContext as _, BackgroundExecutor, Task};
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::env;
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use zeta_prompt::ZetaFormat;
51
52use crate::distill::run_distill;
53use crate::example::{Example, group_examples_by_repo, read_example_files};
54use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
55use crate::format_prompt::run_format_prompt;
56use crate::load_project::run_load_project;
57use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
58use crate::predict::run_prediction;
59use crate::progress::Progress;
60use crate::pull_examples::{fetch_settled_examples_after, parse_settled_after_input};
61use crate::retrieve_context::run_context_retrieval;
62use crate::score::run_scoring;
63use crate::split_commit::SplitCommitArgs;
64use crate::split_dataset::SplitArgs;
65use crate::synthesize::{SynthesizeConfig, run_synthesize};
66use crate::truncate_expected_patch::TruncatePatchArgs;
67
// Top-level CLI arguments for the `ep` binary. Most options are `global` so
// they can be written before or after the subcommand.
//
// NOTE(review): `///` doc comments on clap fields are rendered as `--help`
// text, so review annotations below deliberately use plain `//` comments.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the inherited shell environment and exit immediately; handled in
    // `main` before any subcommand dispatch (see `util::shell_env::print_env`).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Parallelism cap; consumed by the subcommand runners outside this view.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to skip. Applied to file inputs first, with
    // any remainder forwarded to Snowflake queries (see `load_examples`).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    /// Deduplicate by cursor position and keep at most this many examples per cluster
    #[clap(long, global = true)]
    max_duplicates: Option<usize>,
    #[command(subcommand)]
    command: Option<Command>,
    /// Input file paths
    #[clap(global = true)]
    inputs: Vec<PathBuf>,
    // Output destination; see `output_path` for its interaction with --in-place.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place instead of writing elsewhere.
    // `output_path` panics unless exactly one input is given.
    #[arg(long, short, global = true)]
    in_place: bool,
    // NOTE(review): presumably aborts on the first failed example; the flag is
    // not consumed in this view — confirm in the command runners.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
110
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Parsed from `--failed` (clap `ValueEnum`); the variant `///` docs double as
// the CLI help strings, so they are left untouched here.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
123
// Extended reference text appended to `ep read --help` via `after_help` on
// `ReadArgs`. This is a user-facing runtime string — edit its content with care.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
        Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
        Fetch captured examples from Snowflake after the given RFC3339 timestamp.
        These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
        Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
        These are predictions that were shown to users but rejected (useful for DPO training).

  settled-after:{timestamp}
        Fetch settled stream examples from Snowflake after the given RFC3339 timestamp.
        These are examples from the edit prediction settled stream.

  rated-after:{timestamp}
        Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
        These are predictions that users explicitly rated as positive or negative via the
        rate completions modal. Only zeta2 predictions are included.
        - Positive ratings: output becomes expected_patches
        - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
        Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
        Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
      EP_SNOWFLAKE_API_KEY
      EP_SNOWFLAKE_BASE_URL

  Optional:
      EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read settled stream examples
  ep read settled-after:2025-01-01T00:00:00Z -o settled.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
188
// All `ep` subcommands. The variant `///` docs are the clap help strings
// (left untouched); the `Display` impl below renders each variant back into
// its CLI spelling.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read(ReadArgs),
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
232
233impl Display for Command {
234 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235 match self {
236 Command::Read(_) => write!(f, "read"),
237 Command::LoadProject => write!(f, "load-project"),
238 Command::Context => write!(f, "context"),
239 Command::FormatPrompt(args) => {
240 write!(f, "format-prompt --provider={}", args.provider)
241 }
242 Command::Predict(args) => match &args.provider {
243 Some(provider) => write!(f, "predict --provider={}", provider),
244 None => write!(f, "predict"),
245 },
246 Command::ParseOutput => write!(f, "parse-output"),
247 Command::Score(args) => match &args.provider {
248 Some(provider) => write!(f, "score --provider={}", provider),
249 None => write!(f, "score"),
250 },
251 Command::Distill => write!(f, "distill"),
252 Command::Eval(args) => match &args.predict.provider {
253 Some(provider) => write!(f, "eval --provider={}", provider),
254 None => write!(f, "eval"),
255 },
256 Command::Synthesize(args) => {
257 write!(f, "synthesize --repos {}", args.repos.join(" "))
258 }
259 Command::Clean => write!(f, "clean"),
260 Command::SplitCommit(_) => write!(f, "split-commit"),
261 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
262 Command::Split(_) => write!(f, "split"),
263 Command::FilterLanguages(_) => write!(f, "filter-languages"),
264 Command::ImportBatch(args) => {
265 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
266 }
267 Command::Qa(_) => {
268 write!(f, "qa")
269 }
270 Command::Repair(_) => {
271 write!(f, "repair")
272 }
273 Command::PrintZetaFormats => {
274 write!(f, "print-zeta-formats")
275 }
276 }
277 }
278}
279
// `ep read` has no options of its own; inputs and output come from the shared
// global args. `after_help` appends the input-specifier reference text.
#[derive(Debug, Args, Clone)]
#[command(after_help = INPUTS_HELP)]
struct ReadArgs {}
283
// Options for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Which provider's prompt format to generate; defaults to zeta2 with the
    // default zeta format (see `PredictionProvider::default`).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
289
// Options shared by `ep predict`, `ep score`, and (via flattening) `ep eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Provider to run; `None` leaves the choice to the command's own default.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // NOTE(review): presumably the number of prediction attempts per example —
    // confirm in the predict runner.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
    /// Wait for all batches to complete before exiting (only applies to batched providers like teacher)
    #[clap(long)]
    wait: bool,
}
303
// Options for `ep eval`; includes all of `PredictArgs` via clap flattening.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
    /// Print all individual example lines (default: up to 20)
    #[clap(long)]
    verbose: bool,
}
315
/// Teacher model used by the `teacher*` prediction providers.
/// Parsed from provider strings like `teacher:sonnet46` (see `FromStr` below);
/// mapped to concrete model ids by [`TeacherBackend::model_name`].
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Claude Sonnet 4.6 ("claude-sonnet-4-6").
    Sonnet46,
    /// Claude Sonnet 4.5 ("claude-sonnet-4-5") — the default backend.
    #[default]
    Sonnet45,
    /// GPT 5.2 ("gpt-5.2").
    Gpt52,
}
323
324impl std::fmt::Display for TeacherBackend {
325 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326 match self {
327 TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
328 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
329 TeacherBackend::Gpt52 => write!(f, "gpt52"),
330 }
331 }
332}
333
334impl std::str::FromStr for TeacherBackend {
335 type Err = anyhow::Error;
336
337 fn from_str(s: &str) -> Result<Self, Self::Err> {
338 match s.to_lowercase().as_str() {
339 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
340 "sonnet46" => Ok(TeacherBackend::Sonnet46),
341 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
342 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
343 _ => anyhow::bail!(
344 "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
345 ),
346 }
347 }
348}
349
350impl TeacherBackend {
351 pub fn model_name(&self) -> &'static str {
352 match self {
353 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
354 TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
355 TeacherBackend::Gpt52 => "gpt-5.2",
356 }
357 }
358}
359
/// Backend used to produce edit predictions. Round-trips to/from strings of
/// the form `name` or `name:arg` (see the `Display`/`FromStr` impls below),
/// which is also the serde representation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Mercury,
    Zeta1,
    // Zeta variants carry the prompt format to use.
    Zeta2(ZetaFormat),
    Baseten(ZetaFormat),
    // Teacher variants carry which teacher model backs them.
    Teacher(TeacherBackend),
    TeacherMultiRegion(TeacherBackend),
    TeacherNonBatching(TeacherBackend),
    TeacherMultiRegionNonBatching(TeacherBackend),
    Repair,
}
372
373impl Default for PredictionProvider {
374 fn default() -> Self {
375 PredictionProvider::Zeta2(ZetaFormat::default())
376 }
377}
378
379impl std::fmt::Display for PredictionProvider {
380 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
381 match self {
382 PredictionProvider::Mercury => write!(f, "mercury"),
383 PredictionProvider::Zeta1 => write!(f, "zeta1"),
384 PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
385 PredictionProvider::Baseten(format) => write!(f, "baseten:{format}"),
386 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
387 PredictionProvider::TeacherMultiRegion(backend) => {
388 write!(f, "teacher-multi-region:{backend}")
389 }
390 PredictionProvider::TeacherNonBatching(backend) => {
391 write!(f, "teacher-non-batching:{backend}")
392 }
393 PredictionProvider::TeacherMultiRegionNonBatching(backend) => {
394 write!(f, "teacher-multi-region-non-batching:{backend}")
395 }
396 PredictionProvider::Repair => write!(f, "repair"),
397 }
398 }
399}
400
401impl std::str::FromStr for PredictionProvider {
402 type Err = anyhow::Error;
403
404 fn from_str(s: &str) -> Result<Self, Self::Err> {
405 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
406
407 let provider_lower = provider.to_lowercase();
408 match provider_lower.as_str() {
409 "mercury" => Ok(PredictionProvider::Mercury),
410 "zeta1" => Ok(PredictionProvider::Zeta1),
411 "zeta2" => {
412 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
413 Ok(PredictionProvider::Zeta2(format))
414 }
415 "teacher" => {
416 let backend = arg
417 .map(|a| a.parse())
418 .transpose()?
419 .unwrap_or(TeacherBackend::default());
420 Ok(PredictionProvider::Teacher(backend))
421 }
422 "teacher-multi-region" | "teacher_multi_region" => {
423 let backend = arg
424 .map(|a| a.parse())
425 .transpose()?
426 .unwrap_or(TeacherBackend::default());
427 Ok(PredictionProvider::TeacherMultiRegion(backend))
428 }
429 "teacher-non-batching" | "teacher_non_batching" => {
430 let backend = arg
431 .map(|a| a.parse())
432 .transpose()?
433 .unwrap_or(TeacherBackend::default());
434 Ok(PredictionProvider::TeacherNonBatching(backend))
435 }
436 "teacher-multi-region-non-batching" | "teacher_multi_region_non_batching" => {
437 let backend = arg
438 .map(|a| a.parse())
439 .transpose()?
440 .unwrap_or(TeacherBackend::default());
441 Ok(PredictionProvider::TeacherMultiRegionNonBatching(backend))
442 }
443 "repair" => Ok(PredictionProvider::Repair),
444 "baseten" => {
445 let format = arg
446 .map(ZetaFormat::parse)
447 .transpose()?
448 .unwrap_or(ZetaFormat::default());
449 Ok(PredictionProvider::Baseten(format))
450 }
451 _ => {
452 anyhow::bail!(
453 "unknown provider `{provider}`. Valid options: mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-multi-region, teacher-multi-region:<backend>, teacher-non-batching, teacher-multi-region-non-batching, repair\n\
454 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
455 For teacher providers, you can specify a backend like `teacher:sonnet46`, `teacher-multi-region:sonnet46`, `teacher-multi-region-non-batching:sonnet46`, or `teacher:gpt52`.\n\
456 Available zeta versions:\n{}",
457 ZetaFormat::options_as_string()
458 )
459 }
460 }
461 }
462}
463
464impl Serialize for PredictionProvider {
465 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
466 where
467 S: Serializer,
468 {
469 serializer.serialize_str(&self.to_string())
470 }
471}
472
473impl<'de> Deserialize<'de> for PredictionProvider {
474 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
475 where
476 D: Deserializer<'de>,
477 {
478 let s = String::deserialize(deserializer)?;
479 s.parse().map_err(serde::de::Error::custom)
480 }
481}
482
// Options for `ep synthesize`, which mines git commits for eval examples
// (see `run_synthesize` / `SynthesizeConfig`).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
501
// Options for `ep import-batch` (see the `Command::ImportBatch` arm in `main`).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
511
// Which LLM batch API `import-batch` talks to; selects between the
// Anthropic and OpenAI clients in `main`. Parsed by clap as a `ValueEnum`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
517
#[cfg(test)]
mod tests {
    use super::*;

    // `FromStr` and `Display` for `PredictionProvider` should round-trip, with
    // underscore aliases normalizing to the canonical hyphenated spelling.

    #[test]
    fn prediction_provider_multi_region_non_batched_round_trips_to_primary_spelling() {
        let provider: PredictionProvider = "teacher-multi-region-non-batching:sonnet46"
            .parse()
            .unwrap();
        assert_eq!(
            provider,
            PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Sonnet46)
        );
        assert_eq!(
            provider.to_string(),
            "teacher-multi-region-non-batching:sonnet46"
        );
    }

    #[test]
    fn prediction_provider_multi_region_non_batched_alias_round_trips_to_primary_spelling() {
        // The underscore spelling parses to the same variant but always
        // displays with hyphens.
        let provider: PredictionProvider =
            "teacher_multi_region_non_batching:gpt52".parse().unwrap();
        assert_eq!(
            provider,
            PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Gpt52)
        );
        assert_eq!(
            provider.to_string(),
            "teacher-multi-region-non-batching:gpt52"
        );
    }
}
551
552impl EpArgs {
553 fn output_path(&self) -> Option<PathBuf> {
554 if self.in_place {
555 if self.inputs.len() == 1 {
556 self.inputs.first().cloned()
557 } else {
558 panic!("--in-place requires exactly one input file")
559 }
560 } else {
561 self.output.clone()
562 }
563 }
564}
565
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
// NOTE(review): interpreted as the minor/patch of a Zed release (e.g. 0.224.1)
// — confirm against `pull_examples::MinCaptureVersion`.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224,
    patch: 1,
};
573
/// Removes duplicate examples in place, in two passes:
/// 1. Exact: drop examples whose `cursor_position` text was already seen.
/// 2. Near: MinHash/LSH over token n-grams of `cursor_position`, cluster
///    candidate pairs with union-find, and keep at most `max_per_cluster`
///    mutually-diverse examples per cluster (see `greedy_max_min_diverse`).
///
/// The relative order of surviving examples is preserved.
fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
    let total_before_exact = examples.len();
    let mut seen_positions = HashSet::default();
    // `insert` returns false on repeats, so `retain` keeps only first sightings.
    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
    log::info!(
        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
        examples.len(),
        total_before_exact - examples.len(),
    );

    // Tuning knobs for the near-duplicate pass: estimated-Jaccard cutoff,
    // requested signature size, and shingle length in tokens.
    const JACCARD_THRESHOLD: f64 = 0.5;
    const NUM_HASHES: usize = 128;
    const TOKEN_NGRAM_SIZE: usize = 5;

    // gaoya derives the LSH banding scheme from the threshold; the effective
    // hash count is whatever exactly fills the bands.
    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
    let num_hashes = num_bands * band_width;
    let minhasher = MinHasher32::new(num_hashes);
    let mut index: MinHashIndex<u32, usize> =
        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);

    // One MinHash signature per example, computed from its cursor-position shingles.
    let signatures: Vec<Vec<u32>> = examples
        .iter()
        .map(|example| {
            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
            minhasher.create_signature(shingles.iter())
        })
        .collect();

    for (id, signature) in signatures.iter().enumerate() {
        index.insert(id, signature.clone());
    }

    // Build clusters via union-find on LSH candidate pairs.
    let mut parent: Vec<usize> = (0..examples.len()).collect();

    // Union-find root lookup with path halving.
    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
        while parent[x] != x {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        x
    }

    for (id, signature) in signatures.iter().enumerate() {
        // Union each example with every candidate the LSH index reports for it.
        for candidate in index.query_owned(signature) {
            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
            if a != b {
                parent[a] = b;
            }
        }
    }

    // Group example indices by their union-find root.
    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
    for id in 0..examples.len() {
        clusters.entry(find(&mut parent, id)).or_default().push(id);
    }

    // From each cluster keep a small, mutually-dissimilar subset.
    let mut keep: HashSet<usize> = HashSet::default();
    for members in clusters.values() {
        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
        keep.extend(selected);
    }

    let total = examples.len();
    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
    kept_indices.sort();

    // Remove from the back so earlier indices stay valid (swap_remove only
    // disturbs positions at or after the removed index), then restore order.
    let mut retained = Vec::with_capacity(kept_indices.len());
    for index in kept_indices.into_iter().rev() {
        retained.push(examples.swap_remove(index));
    }
    retained.reverse();

    *examples = retained;
    log::info!(
        "near-duplicate filter: {total} examples → {} examples ({} removed)",
        examples.len(),
        total - examples.len(),
    );
}
654
655fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
656 if members.len() <= k {
657 return members.to_vec();
658 }
659
660 let mut selected = vec![members[0]];
661 let mut min_dist: HashMap<usize, f64> = HashMap::default();
662 for &member in &members[1..] {
663 let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
664 min_dist.insert(member, dist);
665 }
666
667 while selected.len() < k {
668 let &best = min_dist
669 .iter()
670 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
671 .map(|(id, _)| id)
672 .expect("min_dist should not be empty when selected.len() < k");
673 selected.push(best);
674 min_dist.remove(&best);
675
676 let best_sig = &signatures[best];
677 for (member, current_min) in min_dist.iter_mut() {
678 let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
679 if dist < *current_min {
680 *current_min = dist;
681 }
682 }
683 }
684
685 selected
686}
687
688fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
689 let tokens: Vec<&str> = word_diff::tokenize(code)
690 .into_iter()
691 .filter(|t| !t.trim().is_empty())
692 .collect();
693
694 if tokens.len() < ngram_size {
695 return vec![tokens.join("\0")];
696 }
697
698 tokens
699 .windows(ngram_size)
700 .map(|window| window.join("\0"))
701 .collect()
702}
703
/// Gathers the examples a command will run over.
///
/// Inputs that match the `captured-after:` / `rejected-after:` /
/// `requested-after:` / `settled-after:` / `rated-after:` specifiers are
/// fetched from Snowflake; everything else is treated as an example file path.
/// The shared CLI filters are then applied in order: repo/rev sort, `--name`,
/// `--repo`, resume-from-output (skipped for `--in-place`),
/// `--max-duplicates` deduplication, and `--limit` truncation.
///
/// `--offset` is consumed by file inputs first; any remainder is forwarded to
/// the Snowflake fetches.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Buckets for each specifier kind, plus plain file paths.
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut settled_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    // Classify every input by trying each specifier parser in turn; anything
    // that matches none of them is assumed to be a file path.
    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) = parse_settled_after_input(input_string.as_ref()) {
            settled_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            examples.clear();
            offset - file_example_count
        } else {
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    Progress::global().set_total_examples(examples.len());

    // Whatever --limit budget the file examples didn't use caps each
    // per-timestamp Snowflake query below.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        let max_rows_per_timestamp = remaining_limit_for_snowflake;

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !settled_after_timestamps.is_empty() {
            settled_after_timestamps.sort();

            let mut settled_examples = fetch_settled_examples_after(
                http_client.clone(),
                &settled_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut settled_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            // Last fetch, so `http_client` and `background_executor` can be
            // moved rather than cloned.
            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from the global --name / --repo flags.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(max_duplicates) = args.max_duplicates {
        deduplicate_examples(&mut examples, max_duplicates);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Final bookkeeping for the progress display now that the set is fixed.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
878
879fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
880 let mut hasher = collections::FxHasher::default();
881 spec.hash(&mut hasher);
882 hasher.finish()
883}
884
885fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
886 let file = match File::open(path) {
887 Ok(f) => f,
888 Err(_) => return,
889 };
890
891 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
892
893 let reader = BufReader::new(file);
894 let mut kept_lines = Vec::new();
895 let mut kept_hashes = HashSet::default();
896
897 for line in reader.lines() {
898 let line = match line {
899 Ok(l) => l,
900 Err(_) => continue,
901 };
902
903 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
904 let hash = spec_hash(&output_example.spec);
905 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
906 let is_complete = match command {
907 Command::Qa(_) => output_example
908 .qa
909 .first()
910 .and_then(|q| q.as_ref())
911 .and_then(|q| q.confidence)
912 .is_some(),
913 Command::Repair(_) => output_example.predictions.iter().any(|p| {
914 p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
915 }),
916 _ => true,
917 };
918 if is_complete {
919 kept_hashes.insert(hash);
920 kept_lines.push(line);
921 }
922 }
923 }
924 }
925
926 let total = examples.len();
927 let already_processed = kept_hashes.len();
928
929 eprintln!(
930 "Resuming: {}/{} examples already processed",
931 already_processed, total
932 );
933
934 let file = OpenOptions::new()
935 .write(true)
936 .truncate(true)
937 .open(path)
938 .expect("Failed to open output file for rewriting");
939 let mut writer = BufWriter::new(file);
940 for line in &kept_lines {
941 writeln!(writer, "{}", line).expect("Failed to write to output file");
942 }
943 writer.flush().expect("Failed to flush output file");
944
945 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
946}
947
/// CLI entry point.
///
/// Dispatch happens in two phases:
/// 1. Standalone subcommands (ImportBatch, Clean, PrintZetaFormats, Synthesize,
///    SplitCommit, TruncatePatch, Split, FilterLanguages) run to completion
///    here and `return` without starting the headless GPUI app.
/// 2. All other subcommands need loaded examples and a headless GPUI app;
///    they fall through to `app.run` below, where examples are processed by
///    parallel workers, grouped by repository.
fn main() {
    let args = EpArgs::parse();

    // --printenv short-circuits everything else.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    // Markdown output is one file per example, so an output directory is mandatory.
    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // No subcommand: print help and exit successfully.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Phase 1: subcommands that don't need the headless app or examples.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Wipe the whole local data directory.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::PrintZetaFormats => {
            use strum::IntoEnumIterator as _;
            for format in ZetaFormat::iter() {
                println!("{}", format.to_string().to_lowercase());
            }
            return;
        }

        Command::Synthesize(synth_args) => {
            // When -o is omitted, default to a directory inside the repo
            // checkout; create it if its parent exists, otherwise bail.
            let output_dir = if let Some(output_dir) = args.output {
                output_dir
            } else {
                let default_output_dir = env::current_dir()
                    .unwrap()
                    .join("crates/edit_prediction_cli/evals-generated");
                if default_output_dir.parent().unwrap().exists() {
                    // `.ok()`: ignore the error when the directory already exists.
                    std::fs::create_dir(&default_output_dir).ok();
                    default_output_dir
                } else {
                    panic!("output dir is required");
                }
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::TruncatePatch(truncate_args) => {
            if let Err(error) =
                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }

        // Everything else is handled by the headless app below.
        _ => {}
    }

    // Phase 2: headless GPUI app for example-driven commands.
    let http_client = Arc::new(ReqwestClient::new());
    let app = gpui_platform::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any outstanding provider batch state for batch-backed
                // commands before processing begins.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // With exactly one example, fail fast so the user sees the
                // error directly instead of a summary.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode).
                // Lines are funneled through a channel to a single background
                // writer task so parallel workers never interleave output.
                // With --in-place we write to a temp file and atomically
                // rename it over the original at the end.
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    // In-place: truncate the temp file; otherwise append to
                    // the existing output (supports resuming).
                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Work queue grouped by repo: a worker pops one repo's worth
                // of examples at a time and processes them sequentially.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                // Spawn up to --max-parallelism workers draining the queue.
                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Run the per-example stage for the current command.
                                let result = async {
                                    match &command {
                                        Command::Read(_) => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        // These commands were fully handled (and
                                        // returned) before `app.run` started.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::TruncatePatch(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::PrintZetaFormats => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Emit the (possibly updated) example unless it
                                // failed and --failed says to drop failures.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o: stream JSONL to stdout (Eval
                                        // prints a report instead, so skip it).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Tear down this repo's project: shut down its
                            // language servers and drop it from the store.
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            // Drop per-example runtime state; only the spec and
                            // results need to survive for final reporting.
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Post-processing: optionally wait for provider batches to
                // settle, reprocess the examples, and rewrite the output
                // (temp path when --in-place is active).
                let is_markdown = args.markdown;
                let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                        if args.wait {
                            predict::wait_for_batches(args.provider.as_ref()).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            predict::reprocess_after_batch_wait(&mut examples, args).await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                        if args.predict.wait {
                            predict::wait_for_batches(args.predict.provider.as_ref()).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            predict::reprocess_after_batch_wait(&mut examples, &args.predict)
                                .await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                        if args.wait {
                            repair::wait_for_batches(args).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            repair::reprocess_after_batch_wait(&mut examples, args).await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    _ => (),
                }

                // Final reports for commands that summarize results.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples, args.verbose);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    Command::Repair(args) => {
                        let examples = finished_examples.lock().unwrap();
                        repair::print_report(&examples, args.confidence_threshold);
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1422
1423fn rewrite_output(
1424 examples: &[Example],
1425 output_path: Option<&PathBuf>,
1426 markdown: bool,
1427) -> anyhow::Result<()> {
1428 if markdown {
1429 let dir = output_path.context("--markdown requires -o")?;
1430 for example in examples {
1431 let filename = format!("{}.md", example.spec.filename());
1432 let path = dir.join(&filename);
1433 let markdown = example.spec.to_markdown();
1434 std::fs::write(&path, &markdown).context("Failed to write markdown file")?;
1435 }
1436 } else if let Some(path) = output_path {
1437 let file = OpenOptions::new()
1438 .create(true)
1439 .write(true)
1440 .truncate(true)
1441 .open(path)
1442 .context("Failed to open output file for rewriting")?;
1443 let mut writer = BufWriter::new(file);
1444 for example in examples {
1445 let line = serde_json::to_string(example)?;
1446 writeln!(writer, "{}", line)?;
1447 }
1448 writer.flush()?;
1449 } else {
1450 for example in examples {
1451 let line = serde_json::to_string(example)?;
1452 println!("{}", line);
1453 }
1454 }
1455 Ok(())
1456}
1457
1458async fn handle_error(
1459 error: anyhow::Error,
1460 args: &EpArgs,
1461 command: &Command,
1462 app_state: &Arc<headless::EpAppState>,
1463 failfast_on_single_example: bool,
1464 example: &Example,
1465) {
1466 Progress::global().increment_failed();
1467
1468 let msg;
1469 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1470 let example_name = example.spec.filename();
1471
1472 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1473 app_state
1474 .fs
1475 .write(
1476 &failed_example_path,
1477 &serde_json::to_vec_pretty(&example).unwrap(),
1478 )
1479 .await
1480 .unwrap();
1481 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1482 app_state
1483 .fs
1484 .write(&err_path, format!("{error:?}").as_bytes())
1485 .await
1486 .unwrap();
1487
1488 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1489 let mut file = OpenOptions::new()
1490 .create(true)
1491 .append(true)
1492 .open(&failed_jsonl_path)
1493 .expect("Failed to open failed.jsonl");
1494 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1495 .expect("Failed to write to failed.jsonl");
1496
1497 let cursor_path = match example.repo_name() {
1498 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1499 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1500 };
1501 msg = format!(
1502 indoc::indoc! {"
1503 While processing \"{}\":
1504
1505 \x1b[31m{:?}\x1b[0m
1506
1507 Example: \x1b[36m{}\x1b[0m
1508 Error file: \x1b[36m{}\x1b[0m
1509 Cursor file: \x1b[36m{}\x1b[0m
1510 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1511 "},
1512 example.spec.name,
1513 error,
1514 failed_example_path.display(),
1515 err_path.display(),
1516 cursor_path.display(),
1517 command,
1518 failed_example_path.display(),
1519 );
1520 } else {
1521 msg = format!(
1522 indoc::indoc! {"
1523 While processing \"{}\":
1524
1525 \x1b[31m{:?}\x1b[0m
1526 "},
1527 example.spec.name, error
1528 );
1529 }
1530
1531 if args.failfast || failfast_on_single_example {
1532 Progress::global().finalize();
1533 panic!("{}", msg);
1534 } else {
1535 log::error!("{}", msg);
1536 }
1537}