1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25
26mod synthesize;
27mod truncate_expected_patch;
28mod word_diff;
29use anyhow::Context as _;
30use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
31use collections::{HashMap, HashSet};
32use edit_prediction::EditPredictionStore;
33use futures::channel::mpsc;
34use futures::{SinkExt as _, StreamExt as _};
35use gaoya::minhash::{
36 MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
37};
38use gpui::{AppContext as _, BackgroundExecutor, Task};
39use zeta_prompt::ZetaFormat;
40
41use reqwest_client::ReqwestClient;
42use serde::{Deserialize, Deserializer, Serialize, Serializer};
43use std::env;
44use std::fmt::Display;
45use std::fs::{File, OpenOptions};
46use std::hash::{Hash, Hasher};
47use std::io::{BufRead, BufReader, BufWriter, Write};
48use std::sync::Mutex;
49use std::{path::PathBuf, sync::Arc};
50
51use crate::distill::run_distill;
52use crate::example::{Example, group_examples_by_repo, read_example_files};
53use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
54use crate::format_prompt::run_format_prompt;
55use crate::load_project::run_load_project;
56use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
57use crate::predict::run_prediction;
58use crate::progress::Progress;
59use crate::pull_examples::{fetch_settled_examples_after, parse_settled_after_input};
60use crate::retrieve_context::run_context_retrieval;
61use crate::score::run_scoring;
62use crate::split_commit::SplitCommitArgs;
63use crate::split_dataset::SplitArgs;
64use crate::synthesize::{SynthesizeConfig, run_synthesize};
65use crate::truncate_expected_patch::TruncatePatchArgs;
66
// Top-level CLI arguments for the `ep` tool. Most flags are `global = true`, so
// they may be given before or after the subcommand. Plain comments (not `///`)
// are used below where we don't want the text to become clap help output.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit immediately (handled first in `main`).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Maximum number of examples processed concurrently — TODO confirm usage in command impls.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to skip; applied to file inputs first, with the
    // remainder forwarded to Snowflake fetches (see `load_examples`).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    /// Deduplicate by cursor position and keep at most this many examples per cluster
    #[clap(long, global = true)]
    max_duplicates: Option<usize>,
    // Subcommand to run; when omitted, `main` prints help and exits.
    #[command(subcommand)]
    command: Option<Command>,
    /// Input file paths
    #[clap(global = true)]
    inputs: Vec<PathBuf>,
    // Output path; interacts with `--in-place` (see `EpArgs::output_path`).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Write results back to the single input file instead of a separate output.
    #[arg(long, short, global = true)]
    in_place: bool,
    // NOTE(review): presumably aborts on the first failed example — confirm in command impls.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
109
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Clap's ValueEnum derive exposes these variants as kebab-case CLI values:
// `keep`, `skip`, `skip-no-files`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
122
/// Extended help text describing the special input specifiers accepted in place
/// of file paths. Rendered after the generated options via `after_help` on
/// `ReadArgs`.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  settled-after:{timestamp}
    Fetch settled stream examples from Snowflake after the given RFC3339 timestamp.
    These are examples from the edit prediction settled stream.

  rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read settled stream examples
  ep read settled-after:2025-01-01T00:00:00Z -o settled.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
187
// All `ep` subcommands. The `///` variant docs double as clap help text, and
// the `Display` impl below renders each variant's command-line spelling.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read(ReadArgs),
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
231
232impl Display for Command {
233 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234 match self {
235 Command::Read(_) => write!(f, "read"),
236 Command::LoadProject => write!(f, "load-project"),
237 Command::Context => write!(f, "context"),
238 Command::FormatPrompt(args) => {
239 write!(f, "format-prompt --provider={}", args.provider)
240 }
241 Command::Predict(args) => match &args.provider {
242 Some(provider) => write!(f, "predict --provider={}", provider),
243 None => write!(f, "predict"),
244 },
245 Command::ParseOutput => write!(f, "parse-output"),
246 Command::Score(args) => match &args.provider {
247 Some(provider) => write!(f, "score --provider={}", provider),
248 None => write!(f, "score"),
249 },
250 Command::Distill => write!(f, "distill"),
251 Command::Eval(args) => match &args.predict.provider {
252 Some(provider) => write!(f, "eval --provider={}", provider),
253 None => write!(f, "eval"),
254 },
255 Command::Synthesize(args) => {
256 write!(f, "synthesize --repos {}", args.repos.join(" "))
257 }
258 Command::Clean => write!(f, "clean"),
259 Command::SplitCommit(_) => write!(f, "split-commit"),
260 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
261 Command::Split(_) => write!(f, "split"),
262 Command::FilterLanguages(_) => write!(f, "filter-languages"),
263 Command::ImportBatch(args) => {
264 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
265 }
266 Command::Qa(_) => {
267 write!(f, "qa")
268 }
269 Command::Repair(_) => {
270 write!(f, "repair")
271 }
272 Command::PrintZetaFormats => {
273 write!(f, "print-zeta-formats")
274 }
275 }
276 }
277}
278
// `ep read` has no flags of its own; its inputs come from the global positional
// arguments. `INPUTS_HELP` documents the accepted input specifiers.
#[derive(Debug, Args, Clone)]
#[command(after_help = INPUTS_HELP)]
struct ReadArgs {}
282
// Arguments for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format is generated; defaults to zeta2 with its default format.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
288
// Shared arguments for `ep predict` / `ep score` (and, flattened, `ep eval`).
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Prediction provider; optional — commands decide what an unset value means.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // NOTE(review): presumably the number of prediction attempts per example — confirm in `run_prediction`.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
    /// Wait for all batches to complete before exiting (only applies to batched providers like teacher)
    #[clap(long)]
    wait: bool,
}
302
// Arguments for `ep eval`; embeds the prediction flags via `flatten`.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // Prediction settings reused by eval (provider, repetitions, caching, batching).
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
    /// Print all individual example lines (default: up to 20)
    #[clap(long)]
    verbose: bool,
}
314
/// The LLM backend used by the `teacher*` prediction providers.
/// See [`TeacherBackend::model_name`] for the provider-side model identifiers.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// `claude-sonnet-4-6`
    Sonnet46,
    /// `claude-sonnet-4-5` — the default backend.
    #[default]
    Sonnet45,
    /// `gpt-5.2`
    Gpt52,
}
322
323impl std::fmt::Display for TeacherBackend {
324 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325 match self {
326 TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
327 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
328 TeacherBackend::Gpt52 => write!(f, "gpt52"),
329 }
330 }
331}
332
333impl std::str::FromStr for TeacherBackend {
334 type Err = anyhow::Error;
335
336 fn from_str(s: &str) -> Result<Self, Self::Err> {
337 match s.to_lowercase().as_str() {
338 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
339 "sonnet46" => Ok(TeacherBackend::Sonnet46),
340 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
341 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
342 _ => anyhow::bail!(
343 "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
344 ),
345 }
346 }
347}
348
349impl TeacherBackend {
350 pub fn model_name(&self) -> &'static str {
351 match self {
352 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
353 TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
354 TeacherBackend::Gpt52 => "gpt-5.2",
355 }
356 }
357}
358
/// Which model/provider produces edit predictions. Parsed from strings like
/// `zeta2:<format>` or `teacher:<backend>` — see the `FromStr` impl below.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Mercury,
    Zeta1,
    /// Zeta2 with a specific prompt format.
    Zeta2(ZetaFormat),
    /// Baseten-hosted endpoint using a zeta prompt format — TODO confirm hosting details.
    Baseten(ZetaFormat),
    /// Teacher model (batched requests).
    Teacher(TeacherBackend),
    /// Teacher model, batched, multi-region.
    TeacherMultiRegion(TeacherBackend),
    /// Teacher model without request batching.
    TeacherNonBatching(TeacherBackend),
    /// Teacher model, multi-region, without request batching.
    TeacherMultiRegionNonBatching(TeacherBackend),
    /// Tags predictions produced by the `repair` command (see `resume_from_output`).
    Repair,
}
371
372impl Default for PredictionProvider {
373 fn default() -> Self {
374 PredictionProvider::Zeta2(ZetaFormat::default())
375 }
376}
377
378impl std::fmt::Display for PredictionProvider {
379 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380 match self {
381 PredictionProvider::Mercury => write!(f, "mercury"),
382 PredictionProvider::Zeta1 => write!(f, "zeta1"),
383 PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
384 PredictionProvider::Baseten(format) => write!(f, "baseten:{format}"),
385 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
386 PredictionProvider::TeacherMultiRegion(backend) => {
387 write!(f, "teacher-multi-region:{backend}")
388 }
389 PredictionProvider::TeacherNonBatching(backend) => {
390 write!(f, "teacher-non-batching:{backend}")
391 }
392 PredictionProvider::TeacherMultiRegionNonBatching(backend) => {
393 write!(f, "teacher-multi-region-non-batching:{backend}")
394 }
395 PredictionProvider::Repair => write!(f, "repair"),
396 }
397 }
398}
399
400impl std::str::FromStr for PredictionProvider {
401 type Err = anyhow::Error;
402
403 fn from_str(s: &str) -> Result<Self, Self::Err> {
404 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
405
406 let provider_lower = provider.to_lowercase();
407 match provider_lower.as_str() {
408 "mercury" => Ok(PredictionProvider::Mercury),
409 "zeta1" => Ok(PredictionProvider::Zeta1),
410 "zeta2" => {
411 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
412 Ok(PredictionProvider::Zeta2(format))
413 }
414 "teacher" => {
415 let backend = arg
416 .map(|a| a.parse())
417 .transpose()?
418 .unwrap_or(TeacherBackend::default());
419 Ok(PredictionProvider::Teacher(backend))
420 }
421 "teacher-multi-region" | "teacher_multi_region" => {
422 let backend = arg
423 .map(|a| a.parse())
424 .transpose()?
425 .unwrap_or(TeacherBackend::default());
426 Ok(PredictionProvider::TeacherMultiRegion(backend))
427 }
428 "teacher-non-batching" | "teacher_non_batching" => {
429 let backend = arg
430 .map(|a| a.parse())
431 .transpose()?
432 .unwrap_or(TeacherBackend::default());
433 Ok(PredictionProvider::TeacherNonBatching(backend))
434 }
435 "teacher-multi-region-non-batching" | "teacher_multi_region_non_batching" => {
436 let backend = arg
437 .map(|a| a.parse())
438 .transpose()?
439 .unwrap_or(TeacherBackend::default());
440 Ok(PredictionProvider::TeacherMultiRegionNonBatching(backend))
441 }
442 "repair" => Ok(PredictionProvider::Repair),
443 "baseten" => {
444 let format = arg
445 .map(ZetaFormat::parse)
446 .transpose()?
447 .unwrap_or(ZetaFormat::default());
448 Ok(PredictionProvider::Baseten(format))
449 }
450 _ => {
451 anyhow::bail!(
452 "unknown provider `{provider}`. Valid options: mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-multi-region, teacher-multi-region:<backend>, teacher-non-batching, teacher-multi-region-non-batching, repair\n\
453 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
454 For teacher providers, you can specify a backend like `teacher:sonnet46`, `teacher-multi-region:sonnet46`, `teacher-multi-region-non-batching:sonnet46`, or `teacher:gpt52`.\n\
455 Available zeta versions:\n{}",
456 ZetaFormat::options_as_string()
457 )
458 }
459 }
460 }
461}
462
463impl Serialize for PredictionProvider {
464 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
465 where
466 S: Serializer,
467 {
468 serializer.serialize_str(&self.to_string())
469 }
470}
471
472impl<'de> Deserialize<'de> for PredictionProvider {
473 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
474 where
475 D: Deserializer<'de>,
476 {
477 let s = String::deserialize(deserializer)?;
478 s.parse().map_err(serde::de::Error::custom)
479 }
480}
481
// Arguments for `ep synthesize`; copied into a `SynthesizeConfig` in `main`.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
500
// Arguments for `ep import-batch` (recovering batch results by id).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
510
// Which LLM batch API an `import-batch` invocation targets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    // Anthropic message batches (msgbatch_xxx ids).
    Anthropic,
    // OpenAI batches (batch_xxx ids).
    Openai,
}
516
#[cfg(test)]
mod tests {
    use super::*;

    // `FromStr` accepts both dash and underscore spellings, while `Display`
    // always renders the dashed primary spelling — so round-trips normalize.
    #[test]
    fn prediction_provider_multi_region_non_batched_round_trips_to_primary_spelling() {
        let provider: PredictionProvider = "teacher-multi-region-non-batching:sonnet46"
            .parse()
            .unwrap();
        assert_eq!(
            provider,
            PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Sonnet46)
        );
        assert_eq!(
            provider.to_string(),
            "teacher-multi-region-non-batching:sonnet46"
        );
    }

    // The underscore alias parses to the same variant and renders dashed.
    #[test]
    fn prediction_provider_multi_region_non_batched_alias_round_trips_to_primary_spelling() {
        let provider: PredictionProvider =
            "teacher_multi_region_non_batching:gpt52".parse().unwrap();
        assert_eq!(
            provider,
            PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Gpt52)
        );
        assert_eq!(
            provider.to_string(),
            "teacher-multi-region-non-batching:gpt52"
        );
    }
}
550
551impl EpArgs {
552 fn output_path(&self) -> Option<PathBuf> {
553 if self.in_place {
554 if self.inputs.len() == 1 {
555 self.inputs.first().cloned()
556 } else {
557 panic!("--in-place requires exactly one input file")
558 }
559 } else {
560 self.output.clone()
561 }
562 }
563}
564
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
// NOTE(review): minor/patch presumably correspond to a Zed `0.minor.patch`
// version — confirm against `pull_examples::MinCaptureVersion`.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224,
    patch: 1,
};
572
/// Removes duplicate examples in two passes: an exact pass that keeps the first
/// example per distinct `cursor_position`, then a near-duplicate pass that
/// clusters examples by minhash/LSH similarity of the `cursor_position` text
/// and keeps at most `max_per_cluster` diverse representatives per cluster.
fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
    // Pass 1: drop examples whose cursor position is identical to one already seen.
    let total_before_exact = examples.len();
    let mut seen_positions = HashSet::default();
    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
    log::info!(
        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
        examples.len(),
        total_before_exact - examples.len(),
    );

    // Pass 2: near-duplicate detection via minhash LSH over token n-gram shingles.
    const JACCARD_THRESHOLD: f64 = 0.5;
    const NUM_HASHES: usize = 128;
    const TOKEN_NGRAM_SIZE: usize = 5;

    // Bands/width are derived from the threshold; the effective hash count is
    // their product, which may differ slightly from NUM_HASHES.
    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
    let num_hashes = num_bands * band_width;
    let minhasher = MinHasher32::new(num_hashes);
    let mut index: MinHashIndex<u32, usize> =
        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);

    // One minhash signature per example, computed from its cursor-position shingles.
    let signatures: Vec<Vec<u32>> = examples
        .iter()
        .map(|example| {
            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
            minhasher.create_signature(shingles.iter())
        })
        .collect();

    for (id, signature) in signatures.iter().enumerate() {
        index.insert(id, signature.clone());
    }

    // Build clusters via union-find on LSH candidate pairs.
    let mut parent: Vec<usize> = (0..examples.len()).collect();

    // Union-find `find` with path halving.
    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
        while parent[x] != x {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        x
    }

    // Union each example with every LSH candidate it collides with.
    for (id, signature) in signatures.iter().enumerate() {
        for candidate in index.query_owned(signature) {
            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
            if a != b {
                parent[a] = b;
            }
        }
    }

    // Group example indices by their union-find root.
    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
    for id in 0..examples.len() {
        clusters.entry(find(&mut parent, id)).or_default().push(id);
    }

    // Keep at most `max_per_cluster` maximally spread-out members per cluster.
    let mut keep: HashSet<usize> = HashSet::default();
    for members in clusters.values() {
        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
        keep.extend(selected);
    }

    let total = examples.len();
    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
    kept_indices.sort();

    // Remove from the highest index down so `swap_remove` never disturbs an index
    // still pending removal; the final `reverse` restores the original order.
    let mut retained = Vec::with_capacity(kept_indices.len());
    for index in kept_indices.into_iter().rev() {
        retained.push(examples.swap_remove(index));
    }
    retained.reverse();

    *examples = retained;
    log::info!(
        "near-duplicate filter: {total} examples → {} examples ({} removed)",
        examples.len(),
        total - examples.len(),
    );
}
653
654fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
655 if members.len() <= k {
656 return members.to_vec();
657 }
658
659 let mut selected = vec![members[0]];
660 let mut min_dist: HashMap<usize, f64> = HashMap::default();
661 for &member in &members[1..] {
662 let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
663 min_dist.insert(member, dist);
664 }
665
666 while selected.len() < k {
667 let &best = min_dist
668 .iter()
669 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
670 .map(|(id, _)| id)
671 .expect("min_dist should not be empty when selected.len() < k");
672 selected.push(best);
673 min_dist.remove(&best);
674
675 let best_sig = &signatures[best];
676 for (member, current_min) in min_dist.iter_mut() {
677 let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
678 if dist < *current_min {
679 *current_min = dist;
680 }
681 }
682 }
683
684 selected
685}
686
687fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
688 let tokens: Vec<&str> = word_diff::tokenize(code)
689 .into_iter()
690 .filter(|t| !t.trim().is_empty())
691 .collect();
692
693 if tokens.len() < ngram_size {
694 return vec![tokens.join("\0")];
695 }
696
697 tokens
698 .windows(ngram_size)
699 .map(|window| window.join("\0"))
700 .collect()
701}
702
/// Loads and prepares the example set for a run.
///
/// File inputs are read directly; inputs matching a Snowflake specifier
/// (`captured-after:`, `rejected-after:`, `requested-after:`, `settled-after:`,
/// or one of the `rated-*after:` forms) are fetched remotely. The combined set
/// is then sorted by repo/rev and reduced by — in order — the name/repo
/// filters, resume-skipping against existing output, near-duplicate removal,
/// and the `--limit` cap.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut settled_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    // Partition inputs into Snowflake specifiers (by kind) and plain file paths.
    // Anything that doesn't parse as a specifier is treated as a file path.
    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) = parse_settled_after_input(input_string.as_ref()) {
            settled_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            examples.clear();
            offset - file_example_count
        } else {
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    // Early total for progress reporting; updated again after filtering below.
    Progress::global().set_total_examples(examples.len());

    // Snowflake fetches only need to fill whatever --limit leaves after files.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        let max_rows_per_timestamp = remaining_limit_for_snowflake;

        // Each specifier kind is fetched separately; timestamps are sorted first,
        // presumably for stable request ordering — confirm in pull_examples.
        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !settled_after_timestamps.is_empty() {
            settled_after_timestamps.sort();

            let mut settled_examples = fetch_settled_examples_after(
                http_client.clone(),
                &settled_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut settled_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            // Last fetch: `http_client`/`background_executor` are moved, not cloned.
            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters on example name / repository URL.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(max_duplicates) = args.max_duplicates {
        deduplicate_examples(&mut examples, max_duplicates);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Final totals for the progress display, after all filtering.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
877
878fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
879 let mut hasher = collections::FxHasher::default();
880 spec.hash(&mut hasher);
881 hasher.finish()
882}
883
884fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
885 let file = match File::open(path) {
886 Ok(f) => f,
887 Err(_) => return,
888 };
889
890 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
891
892 let reader = BufReader::new(file);
893 let mut kept_lines = Vec::new();
894 let mut kept_hashes = HashSet::default();
895
896 for line in reader.lines() {
897 let line = match line {
898 Ok(l) => l,
899 Err(_) => continue,
900 };
901
902 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
903 let hash = spec_hash(&output_example.spec);
904 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
905 let is_complete = match command {
906 Command::Qa(_) => output_example
907 .qa
908 .first()
909 .and_then(|q| q.as_ref())
910 .and_then(|q| q.confidence)
911 .is_some(),
912 Command::Repair(_) => output_example.predictions.iter().any(|p| {
913 p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
914 }),
915 _ => true,
916 };
917 if is_complete {
918 kept_hashes.insert(hash);
919 kept_lines.push(line);
920 }
921 }
922 }
923 }
924
925 let total = examples.len();
926 let already_processed = kept_hashes.len();
927
928 eprintln!(
929 "Resuming: {}/{} examples already processed",
930 already_processed, total
931 );
932
933 let file = OpenOptions::new()
934 .write(true)
935 .truncate(true)
936 .open(path)
937 .expect("Failed to open output file for rewriting");
938 let mut writer = BufWriter::new(file);
939 for line in &kept_lines {
940 writeln!(writer, "{}", line).expect("Failed to write to output file");
941 }
942 writer.flush().expect("Failed to flush output file");
943
944 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
945}
946
/// CLI entry point.
///
/// Dispatch happens in two phases:
/// 1. Commands that need no app state (batch import, cleanup, dataset
///    splitting/filtering, synthesis, format listing) are handled
///    synchronously in the first `match` and return early.
/// 2. Every other command falls through to a headless GPUI app that loads
///    the input examples, processes them with up to `--max-parallelism`
///    concurrent workers (one repo group at a time per worker), and writes
///    results as JSONL or markdown.
fn main() {
    let args = EpArgs::parse();

    // `--printenv` short-circuits everything else.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // No subcommand given: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Phase 1: commands that do not need the headless app. Each arm
    // returns early; everything else falls through to phase 2 below.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::PrintZetaFormats => {
            use strum::IntoEnumIterator as _;
            for format in ZetaFormat::iter() {
                println!("{}", format.to_string().to_lowercase());
            }
            return;
        }

        Command::Synthesize(synth_args) => {
            // Without -o, default to a directory inside the repo checkout;
            // only auto-create it when its parent exists.
            let output_dir = if let Some(output_dir) = args.output {
                output_dir
            } else {
                let default_output_dir = env::current_dir()
                    .unwrap()
                    .join("crates/edit_prediction_cli/evals-generated");
                if default_output_dir.parent().unwrap().exists() {
                    std::fs::create_dir(&default_output_dir).ok();
                    default_output_dir
                } else {
                    panic!("output dir is required");
                }
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::TruncatePatch(truncate_args) => {
            if let Err(error) =
                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }

        _ => {}
    }

    // Phase 2: remaining commands run inside a headless GPUI app.
    let http_client = Arc::new(ReqwestClient::new());
    let app = gpui_platform::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Pull down results of previously submitted provider batches
                // before processing. NOTE: `args` in these arms shadows the
                // outer CLI args with the subcommand's own args struct.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // With exactly one example, fail fast so its error surfaces
                // immediately instead of being logged and skipped.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode).
                // With --in-place, write to a `.jsonl.tmp` sibling that is
                // atomically renamed over the original at the end of the run.
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // Background task drains the channel, writing and
                    // flushing one JSONL line per finished example.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                // Spawn `max_parallelism` workers. Each worker repeatedly
                // claims the next repo group and processes its examples
                // sequentially, so parallelism is across repos, not within.
                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Run the selected command against this example.
                                let result = async {
                                    match &command {
                                        Command::Read(_) => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        // All of these returned early in phase 1.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::TruncatePatch(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::PrintZetaFormats => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                // `handle_error` logs/records the failure and
                                // may panic when failing fast.
                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are only written through when
                                // --failed=keep.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o: stream JSONL to stdout (except for
                                        // eval, which prints a report instead).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // After a repo group finishes, shut down its language
                            // servers and detach the project from the prediction
                            // store before dropping per-example state.
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // If --wait was given, block on outstanding provider batches,
                // reprocess the finished examples with the batch results, and
                // rewrite the output from scratch.
                let is_markdown = args.markdown;
                let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                        if args.wait {
                            predict::wait_for_batches(args.provider.as_ref()).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            predict::reprocess_after_batch_wait(&mut examples, args).await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                        if args.predict.wait {
                            predict::wait_for_batches(args.predict.provider.as_ref()).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            predict::reprocess_after_batch_wait(&mut examples, &args.predict)
                                .await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                        if args.wait {
                            repair::wait_for_batches(args).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            repair::reprocess_after_batch_wait(&mut examples, args).await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    _ => (),
                }

                // Per-command summary reports.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples, args.verbose);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    Command::Repair(args) => {
                        let examples = finished_examples.lock().unwrap();
                        repair::print_report(&examples, args.confidence_threshold);
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1421
1422fn rewrite_output(
1423 examples: &[Example],
1424 output_path: Option<&PathBuf>,
1425 markdown: bool,
1426) -> anyhow::Result<()> {
1427 if markdown {
1428 let dir = output_path.context("--markdown requires -o")?;
1429 for example in examples {
1430 let filename = format!("{}.md", example.spec.filename());
1431 let path = dir.join(&filename);
1432 let markdown = example.spec.to_markdown();
1433 std::fs::write(&path, &markdown).context("Failed to write markdown file")?;
1434 }
1435 } else if let Some(path) = output_path {
1436 let file = OpenOptions::new()
1437 .create(true)
1438 .write(true)
1439 .truncate(true)
1440 .open(path)
1441 .context("Failed to open output file for rewriting")?;
1442 let mut writer = BufWriter::new(file);
1443 for example in examples {
1444 let line = serde_json::to_string(example)?;
1445 writeln!(writer, "{}", line)?;
1446 }
1447 writer.flush()?;
1448 } else {
1449 for example in examples {
1450 let line = serde_json::to_string(example)?;
1451 println!("{}", line);
1452 }
1453 }
1454 Ok(())
1455}
1456
1457async fn handle_error(
1458 error: anyhow::Error,
1459 args: &EpArgs,
1460 command: &Command,
1461 app_state: &Arc<headless::EpAppState>,
1462 failfast_on_single_example: bool,
1463 example: &Example,
1464) {
1465 Progress::global().increment_failed();
1466
1467 let msg;
1468 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1469 let example_name = example.spec.filename();
1470
1471 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1472 app_state
1473 .fs
1474 .write(
1475 &failed_example_path,
1476 &serde_json::to_vec_pretty(&example).unwrap(),
1477 )
1478 .await
1479 .unwrap();
1480 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1481 app_state
1482 .fs
1483 .write(&err_path, format!("{error:?}").as_bytes())
1484 .await
1485 .unwrap();
1486
1487 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1488 let mut file = OpenOptions::new()
1489 .create(true)
1490 .append(true)
1491 .open(&failed_jsonl_path)
1492 .expect("Failed to open failed.jsonl");
1493 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1494 .expect("Failed to write to failed.jsonl");
1495
1496 let cursor_path = match example.repo_name() {
1497 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1498 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1499 };
1500 msg = format!(
1501 indoc::indoc! {"
1502 While processing \"{}\":
1503
1504 \x1b[31m{:?}\x1b[0m
1505
1506 Example: \x1b[36m{}\x1b[0m
1507 Error file: \x1b[36m{}\x1b[0m
1508 Cursor file: \x1b[36m{}\x1b[0m
1509 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1510 "},
1511 example.spec.name,
1512 error,
1513 failed_example_path.display(),
1514 err_path.display(),
1515 cursor_path.display(),
1516 command,
1517 failed_example_path.display(),
1518 );
1519 } else {
1520 msg = format!(
1521 indoc::indoc! {"
1522 While processing \"{}\":
1523
1524 \x1b[31m{:?}\x1b[0m
1525 "},
1526 example.spec.name, error
1527 );
1528 }
1529
1530 if args.failfast || failfast_on_single_example {
1531 Progress::global().finalize();
1532 panic!("{}", msg);
1533 } else {
1534 log::error!("{}", msg);
1535 }
1536}