1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25
26mod synthesize;
27mod truncate_expected_patch;
28mod word_diff;
use anyhow::Context as _;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::{HashMap, HashSet};
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gaoya::minhash::{
    MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
};
use gpui::{AppContext as _, BackgroundExecutor, Task};
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::env;
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use zeta_prompt::ZetaFormat;
50
51use crate::distill::run_distill;
52use crate::example::{Example, group_examples_by_repo, read_example_files};
53use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
54use crate::format_prompt::run_format_prompt;
55use crate::load_project::run_load_project;
56use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
57use crate::predict::run_prediction;
58use crate::progress::Progress;
59use crate::pull_examples::{fetch_settled_examples_after, parse_settled_after_input};
60use crate::retrieve_context::run_context_retrieval;
61use crate::score::run_scoring;
62use crate::split_commit::SplitCommitArgs;
63use crate::split_dataset::SplitArgs;
64use crate::synthesize::{SynthesizeConfig, run_synthesize};
65use crate::truncate_expected_patch::TruncatePatchArgs;
66
// Top-level CLI arguments for the `ep` edit-prediction tool. Most flags are
// marked `global` so they can appear before or after the subcommand.
// NOTE: `//` comments are used here (not `///`) so clap's generated help text
// is unchanged.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the resolved shell environment and exit (handled first in `main`).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Upper bound on concurrent work — presumably consumed by the processing
    // commands; not used in this file directly. TODO confirm at use sites.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to skip; consumed by file examples first,
    // remainder forwarded to Snowflake fetches (see `load_examples`).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    /// Deduplicate by cursor position and keep at most this many examples per cluster
    #[clap(long, global = true)]
    max_duplicates: Option<usize>,
    #[command(subcommand)]
    command: Option<Command>,
    /// Input file paths
    #[clap(global = true)]
    inputs: Vec<PathBuf>,
    // Output path; with --markdown it names a directory instead of a file.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Write results back into the single input file (see `output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
109
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // NOTE(review): presumably this also suppresses the failed/ directory logs,
    // unlike `Skip` — confirm at the write sites before relying on it.
    SkipNoFiles,
}
122
// Extended help text attached to `ep read` via `after_help` (see `ReadArgs`).
// Documents the special input specifiers (`captured-after:`, `rejected-after:`,
// `settled-after:`, `rated-*-after:`) and the Snowflake environment variables
// they require. Runtime text — do not edit casually; it must stay in sync with
// the specifier parsing in `load_examples` / `pull_examples`.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  settled-after:{timestamp}
    Fetch settled stream examples from Snowflake after the given RFC3339 timestamp.
    These are examples from the edit prediction settled stream.

  rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read settled stream examples
  ep read settled-after:2025-01-01T00:00:00Z -o settled.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
187
// All `ep` subcommands. The `///` variant docs double as clap help text, and
// the `Display` impl below renders each command as its CLI spelling.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read(ReadArgs),
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
231
232impl Display for Command {
233 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234 match self {
235 Command::Read(_) => write!(f, "read"),
236 Command::LoadProject => write!(f, "load-project"),
237 Command::Context => write!(f, "context"),
238 Command::FormatPrompt(args) => {
239 write!(f, "format-prompt --provider={}", args.provider)
240 }
241 Command::Predict(args) => match &args.provider {
242 Some(provider) => write!(f, "predict --provider={}", provider),
243 None => write!(f, "predict"),
244 },
245 Command::ParseOutput => write!(f, "parse-output"),
246 Command::Score(args) => match &args.provider {
247 Some(provider) => write!(f, "score --provider={}", provider),
248 None => write!(f, "score"),
249 },
250 Command::Distill => write!(f, "distill"),
251 Command::Eval(args) => match &args.predict.provider {
252 Some(provider) => write!(f, "eval --provider={}", provider),
253 None => write!(f, "eval"),
254 },
255 Command::Synthesize(args) => {
256 write!(f, "synthesize --repos {}", args.repos.join(" "))
257 }
258 Command::Clean => write!(f, "clean"),
259 Command::SplitCommit(_) => write!(f, "split-commit"),
260 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
261 Command::Split(_) => write!(f, "split"),
262 Command::FilterLanguages(_) => write!(f, "filter-languages"),
263 Command::ImportBatch(args) => {
264 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
265 }
266 Command::Qa(_) => {
267 write!(f, "qa")
268 }
269 Command::Repair(_) => {
270 write!(f, "repair")
271 }
272 Command::PrintZetaFormats => {
273 write!(f, "print-zeta-formats")
274 }
275 }
276 }
277}
278
// `ep read` has no flags of its own; inputs come from the global positional
// arguments. The after_help text documents the special input specifiers.
#[derive(Debug, Args, Clone)]
#[command(after_help = INPUTS_HELP)]
struct ReadArgs {}
282
// Arguments for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to emit; defaults to zeta2 with the default
    // format (see `PredictionProvider::default`).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
288
// Arguments shared by `predict`, `score`, and (flattened) `eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Provider to run; `None` leaves the choice to downstream code
    // (see how `Display for Command` renders the bare command in that case).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Presumably the number of prediction attempts per example — confirm in
    // the predict implementation.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
    /// Wait for all batches to complete before exiting (only applies to batched providers like teacher)
    #[clap(long)]
    wait: bool,
}
302
// Arguments for `ep eval`: prediction flags (flattened `PredictArgs`) plus
// score-reporting options.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
    /// Print all individual example lines (default: up to 20)
    #[clap(long)]
    verbose: bool,
}
314
// Which large model backs the "teacher" prediction providers. `model_name`
// maps each variant to the concrete API model identifier; `Display`/`FromStr`
// handle the CLI spellings.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    Sonnet46,
    // Default backend when no `:backend` suffix is given on the CLI.
    #[default]
    Sonnet45,
    Gpt52,
}
322
323impl std::fmt::Display for TeacherBackend {
324 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325 match self {
326 TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
327 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
328 TeacherBackend::Gpt52 => write!(f, "gpt52"),
329 }
330 }
331}
332
impl std::str::FromStr for TeacherBackend {
    type Err = anyhow::Error;

    /// Parses a backend name case-insensitively, accepting convenience aliases
    /// (`sonnet`/`claude`, `gpt`/`openai`) for the primary spellings.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
            "sonnet46" => Ok(TeacherBackend::Sonnet46),
            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
            // NOTE(review): presumably a legacy format identifier kept so older
            // saved data still parses — confirm before removing. It is deliberately
            // absent from the error message's list of valid options.
            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
            _ => anyhow::bail!(
                "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
            ),
        }
    }
}
348
349impl TeacherBackend {
350 pub fn model_name(&self) -> &'static str {
351 match self {
352 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
353 TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
354 TeacherBackend::Gpt52 => "gpt-5.2",
355 }
356 }
357}
358
// The systems that can produce edit predictions. Zeta-style providers carry a
// prompt `ZetaFormat`; teacher providers carry the backing `TeacherBackend`.
// `Display`/`FromStr` below define the `provider[:arg]` CLI spellings.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2(ZetaFormat),
    Baseten(ZetaFormat),
    Teacher(TeacherBackend),
    // Multi-region / non-batching teacher variants — presumably selecting
    // request routing and batch-vs-direct behavior; confirm in the prediction
    // implementation.
    TeacherMultiRegion(TeacherBackend),
    TeacherNonBatching(TeacherBackend),
    TeacherMultiRegionNonBatching(TeacherBackend),
    Repair,
}
372
373impl Default for PredictionProvider {
374 fn default() -> Self {
375 PredictionProvider::Zeta2(ZetaFormat::default())
376 }
377}
378
379impl std::fmt::Display for PredictionProvider {
380 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
381 match self {
382 PredictionProvider::Sweep => write!(f, "sweep"),
383 PredictionProvider::Mercury => write!(f, "mercury"),
384 PredictionProvider::Zeta1 => write!(f, "zeta1"),
385 PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
386 PredictionProvider::Baseten(format) => write!(f, "baseten:{format}"),
387 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
388 PredictionProvider::TeacherMultiRegion(backend) => {
389 write!(f, "teacher-multi-region:{backend}")
390 }
391 PredictionProvider::TeacherNonBatching(backend) => {
392 write!(f, "teacher-non-batching:{backend}")
393 }
394 PredictionProvider::TeacherMultiRegionNonBatching(backend) => {
395 write!(f, "teacher-multi-region-non-batching:{backend}")
396 }
397 PredictionProvider::Repair => write!(f, "repair"),
398 }
399 }
400}
401
402impl std::str::FromStr for PredictionProvider {
403 type Err = anyhow::Error;
404
405 fn from_str(s: &str) -> Result<Self, Self::Err> {
406 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
407
408 let provider_lower = provider.to_lowercase();
409 match provider_lower.as_str() {
410 "sweep" => Ok(PredictionProvider::Sweep),
411 "mercury" => Ok(PredictionProvider::Mercury),
412 "zeta1" => Ok(PredictionProvider::Zeta1),
413 "zeta2" => {
414 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
415 Ok(PredictionProvider::Zeta2(format))
416 }
417 "teacher" => {
418 let backend = arg
419 .map(|a| a.parse())
420 .transpose()?
421 .unwrap_or(TeacherBackend::default());
422 Ok(PredictionProvider::Teacher(backend))
423 }
424 "teacher-multi-region" | "teacher_multi_region" => {
425 let backend = arg
426 .map(|a| a.parse())
427 .transpose()?
428 .unwrap_or(TeacherBackend::default());
429 Ok(PredictionProvider::TeacherMultiRegion(backend))
430 }
431 "teacher-non-batching" | "teacher_non_batching" => {
432 let backend = arg
433 .map(|a| a.parse())
434 .transpose()?
435 .unwrap_or(TeacherBackend::default());
436 Ok(PredictionProvider::TeacherNonBatching(backend))
437 }
438 "teacher-multi-region-non-batching" | "teacher_multi_region_non_batching" => {
439 let backend = arg
440 .map(|a| a.parse())
441 .transpose()?
442 .unwrap_or(TeacherBackend::default());
443 Ok(PredictionProvider::TeacherMultiRegionNonBatching(backend))
444 }
445 "repair" => Ok(PredictionProvider::Repair),
446 "baseten" => {
447 let format = arg
448 .map(ZetaFormat::parse)
449 .transpose()?
450 .unwrap_or(ZetaFormat::default());
451 Ok(PredictionProvider::Baseten(format))
452 }
453 _ => {
454 anyhow::bail!(
455 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-multi-region, teacher-multi-region:<backend>, teacher-non-batching, teacher-multi-region-non-batching, repair\n\
456 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
457 For teacher providers, you can specify a backend like `teacher:sonnet46`, `teacher-multi-region:sonnet46`, `teacher-multi-region-non-batching:sonnet46`, or `teacher:gpt52`.\n\
458 Available zeta versions:\n{}",
459 ZetaFormat::options_as_string()
460 )
461 }
462 }
463 }
464}
465
466impl Serialize for PredictionProvider {
467 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
468 where
469 S: Serializer,
470 {
471 serializer.serialize_str(&self.to_string())
472 }
473}
474
475impl<'de> Deserialize<'de> for PredictionProvider {
476 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
477 where
478 D: Deserializer<'de>,
479 {
480 let s = String::deserialize(deserializer)?;
481 s.parse().map_err(serde::de::Error::custom)
482 }
483}
484
// Arguments for `ep synthesize`; copied into `SynthesizeConfig` in `main`
// before calling `run_synthesize`.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
503
// Arguments for `ep import-batch`; handled by the `Command::ImportBatch` arm
// in `main`, which imports the batches into the LLM cache DB.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
513
// Which LLM batch API `import-batch` pulls from; selects between the
// Anthropic and OpenAI clients in `main`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
519
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn prediction_provider_multi_region_non_batched_round_trips_to_primary_spelling() {
        // The canonical hyphenated spelling parses to the expected variant…
        let parsed = "teacher-multi-region-non-batching:sonnet46"
            .parse::<PredictionProvider>()
            .unwrap();
        assert_eq!(
            parsed,
            PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Sonnet46)
        );
        // …and Display reproduces that same canonical spelling.
        assert_eq!(
            parsed.to_string(),
            "teacher-multi-region-non-batching:sonnet46"
        );
    }

    #[test]
    fn prediction_provider_multi_region_non_batched_alias_round_trips_to_primary_spelling() {
        // The underscore alias parses to the same variant…
        let parsed = "teacher_multi_region_non_batching:gpt52"
            .parse::<PredictionProvider>()
            .unwrap();
        assert_eq!(
            parsed,
            PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Gpt52)
        );
        // …while Display normalizes to the hyphenated primary spelling.
        assert_eq!(
            parsed.to_string(),
            "teacher-multi-region-non-batching:gpt52"
        );
    }
}
553
554impl EpArgs {
555 fn output_path(&self) -> Option<PathBuf> {
556 if self.in_place {
557 if self.inputs.len() == 1 {
558 self.inputs.first().cloned()
559 } else {
560 panic!("--in-place requires exactly one input file")
561 }
562 } else {
563 self.output.clone()
564 }
565 }
566}
567
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
///
/// Passed as `Some(MIN_CAPTURE_VERSION)` to every `pull_examples::fetch_*`
/// call in `load_examples`.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224,
    patch: 1,
};
575
/// Deduplicates `examples` in place, in two stages:
///
/// 1. Exact: drop examples whose `spec.cursor_position` was already seen.
/// 2. Near: build MinHash signatures over token n-gram shingles of the cursor
///    position text, find candidate pairs with LSH, cluster them via
///    union-find, and keep at most `max_per_cluster` mutually-diverse
///    representatives per cluster (see `greedy_max_min_diverse`).
fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
    // --- Stage 1: exact dedupe on cursor position ---
    let total_before_exact = examples.len();
    let mut seen_positions = HashSet::default();
    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
    log::info!(
        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
        examples.len(),
        total_before_exact - examples.len(),
    );

    // --- Stage 2: near-duplicate clustering ---
    // Pairs whose estimated Jaccard similarity exceeds this become candidates.
    const JACCARD_THRESHOLD: f64 = 0.5;
    const NUM_HASHES: usize = 128;
    const TOKEN_NGRAM_SIZE: usize = 5;

    // gaoya picks a band layout for the threshold; the effective hash count is
    // num_bands * band_width (may differ slightly from NUM_HASHES).
    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
    let num_hashes = num_bands * band_width;
    let minhasher = MinHasher32::new(num_hashes);
    let mut index: MinHashIndex<u32, usize> =
        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);

    // One MinHash signature per example, built from its shingles.
    let signatures: Vec<Vec<u32>> = examples
        .iter()
        .map(|example| {
            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
            minhasher.create_signature(shingles.iter())
        })
        .collect();

    for (id, signature) in signatures.iter().enumerate() {
        index.insert(id, signature.clone());
    }

    // Build clusters via union-find on LSH candidate pairs.
    let mut parent: Vec<usize> = (0..examples.len()).collect();

    // Union-find `find` with path halving: each step links a node to its
    // grandparent, flattening the tree as we walk up.
    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
        while parent[x] != x {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        x
    }

    for (id, signature) in signatures.iter().enumerate() {
        for candidate in index.query_owned(signature) {
            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
            if a != b {
                parent[a] = b;
            }
        }
    }

    // Group example indices by union-find root.
    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
    for id in 0..examples.len() {
        clusters.entry(find(&mut parent, id)).or_default().push(id);
    }

    // Keep up to `max_per_cluster` diverse members from each cluster.
    let mut keep: HashSet<usize> = HashSet::default();
    for members in clusters.values() {
        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
        keep.extend(selected);
    }

    let total = examples.len();
    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
    kept_indices.sort();

    // Remove from the back with swap_remove so earlier kept indices stay
    // valid; items come out in descending-index order, so reverse afterwards
    // to restore the original relative ordering.
    let mut retained = Vec::with_capacity(kept_indices.len());
    for index in kept_indices.into_iter().rev() {
        retained.push(examples.swap_remove(index));
    }
    retained.reverse();

    *examples = retained;
    log::info!(
        "near-duplicate filter: {total} examples → {} examples ({} removed)",
        examples.len(),
        total - examples.len(),
    );
}
656
/// Picks up to `k` of `members` that are mutually far apart in
/// MinHash-estimated Jaccard distance (1 − similarity), using the greedy
/// max-min (farthest-point) heuristic seeded with the first member.
///
/// NOTE(review): when several members tie for the maximum distance, the winner
/// depends on `HashMap` iteration order, so the selected set may vary across
/// runs — confirm whether that matters for dataset reproducibility.
fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
    // Fewer members than the budget: keep them all.
    if members.len() <= k {
        return members.to_vec();
    }

    // `min_dist[m]` tracks m's distance to its nearest already-selected member.
    let mut selected = vec![members[0]];
    let mut min_dist: HashMap<usize, f64> = HashMap::default();
    for &member in &members[1..] {
        let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
        min_dist.insert(member, dist);
    }

    while selected.len() < k {
        // Take the member farthest from the current selection…
        let &best = min_dist
            .iter()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(id, _)| id)
            .expect("min_dist should not be empty when selected.len() < k");
        selected.push(best);
        min_dist.remove(&best);

        // …then refresh nearest-selected distances against the new pick.
        let best_sig = &signatures[best];
        for (member, current_min) in min_dist.iter_mut() {
            let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
            if dist < *current_min {
                *current_min = dist;
            }
        }
    }

    selected
}
689
690fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
691 let tokens: Vec<&str> = word_diff::tokenize(code)
692 .into_iter()
693 .filter(|t| !t.trim().is_empty())
694 .collect();
695
696 if tokens.len() < ngram_size {
697 return vec![tokens.join("\0")];
698 }
699
700 tokens
701 .windows(ngram_size)
702 .map(|window| window.join("\0"))
703 .collect()
704}
705
/// Collects the examples for this run from every input source.
///
/// Inputs are partitioned into local files and Snowflake specifiers
/// (`captured-after:`, `rejected-after:`, `requested-after:`, `settled-after:`,
/// `rated-*-after:`). File examples load first; `--offset` is consumed by the
/// file examples and any remainder is forwarded to the Snowflake fetches.
/// The combined set is then sorted by repo/rev, filtered by `--name`/`--repo`,
/// pruned via resume state (unless `--in-place`), deduplicated when
/// `--max-duplicates` is set, and finally truncated to `--limit`.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Buckets for each recognized specifier kind; unmatched inputs are files.
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut settled_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) = parse_settled_after_input(input_string.as_ref()) {
            settled_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            examples.clear();
            offset - file_example_count
        } else {
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    // Provisional total so progress reads sensibly during the Snowflake
    // fetches; set again with the final count at the end of this function.
    Progress::global().set_total_examples(examples.len());

    // Budget left for Snowflake after counting the file examples.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // NOTE: the limit is applied per fetch call, not shared across them;
        // the final `truncate(limit)` below enforces the overall cap.
        let max_rows_per_timestamp = remaining_limit_for_snowflake;

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !settled_after_timestamps.is_empty() {
            settled_after_timestamps.sort();

            let mut settled_examples = fetch_settled_examples_after(
                http_client.clone(),
                &settled_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut settled_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            // Last fetch, so the client and executor can be moved rather than cloned.
            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from the global --name / --repo flags.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(max_duplicates) = args.max_duplicates {
        deduplicate_examples(&mut examples, max_duplicates);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Final totals for progress reporting, now that the set is settled.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
880
881fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
882 let mut hasher = collections::FxHasher::default();
883 spec.hash(&mut hasher);
884 hasher.finish()
885}
886
887fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
888 let file = match File::open(path) {
889 Ok(f) => f,
890 Err(_) => return,
891 };
892
893 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
894
895 let reader = BufReader::new(file);
896 let mut kept_lines = Vec::new();
897 let mut kept_hashes = HashSet::default();
898
899 for line in reader.lines() {
900 let line = match line {
901 Ok(l) => l,
902 Err(_) => continue,
903 };
904
905 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
906 let hash = spec_hash(&output_example.spec);
907 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
908 let is_complete = match command {
909 Command::Qa(_) => output_example
910 .qa
911 .first()
912 .and_then(|q| q.as_ref())
913 .and_then(|q| q.confidence)
914 .is_some(),
915 Command::Repair(_) => output_example.predictions.iter().any(|p| {
916 p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
917 }),
918 _ => true,
919 };
920 if is_complete {
921 kept_hashes.insert(hash);
922 kept_lines.push(line);
923 }
924 }
925 }
926 }
927
928 let total = examples.len();
929 let already_processed = kept_hashes.len();
930
931 eprintln!(
932 "Resuming: {}/{} examples already processed",
933 already_processed, total
934 );
935
936 let file = OpenOptions::new()
937 .write(true)
938 .truncate(true)
939 .open(path)
940 .expect("Failed to open output file for rewriting");
941 let mut writer = BufWriter::new(file);
942 for line in &kept_lines {
943 writeln!(writer, "{}", line).expect("Failed to write to output file");
944 }
945 writer.flush().expect("Failed to flush output file");
946
947 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
948}
949
/// CLI entry point.
///
/// Runs in two phases:
/// 1. A synchronous dispatch for subcommands that need no headless app
///    (batch import, clean, format listing, synthesis, dataset splitting,
///    patch truncation, language filtering) — each arm returns.
/// 2. Everything else runs inside a headless GPUI application: examples are
///    loaded, grouped by repository, and processed by up to
///    `--max-parallelism` concurrent workers, with results streamed to
///    stdout, a JSONL file, or per-example markdown files.
fn main() {
    let args = EpArgs::parse();

    // `--printenv` short-circuits everything: dump the shell env and exit.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    // Markdown mode writes one `.md` file per example, so it needs `-o` to
    // name the output directory.
    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // No subcommand given: print help and exit successfully.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Phase 1: commands that do not need the headless GPUI app. Every arm
    // here returns; only the `_ => {}` arm falls through to phase 2.
    match &command {
        // Import previously-submitted provider batches into the LLM cache DB.
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        // Remove the entire local data directory.
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        // List every `ZetaFormat` variant, lowercased, one per line.
        Command::PrintZetaFormats => {
            use strum::IntoEnumIterator as _;
            for format in ZetaFormat::iter() {
                println!("{}", format.to_string().to_lowercase());
            }
            return;
        }

        Command::Synthesize(synth_args) => {
            // Without `-o`, default to a directory inside this crate; it is
            // only auto-created when its parent exists (i.e. when running
            // from the repository root), otherwise an output dir is required.
            let output_dir = if let Some(output_dir) = args.output {
                output_dir
            } else {
                let default_output_dir = env::current_dir()
                    .unwrap()
                    .join("crates/edit_prediction_cli/evals-generated");
                if default_output_dir.parent().unwrap().exists() {
                    std::fs::create_dir(&default_output_dir).ok();
                    default_output_dir
                } else {
                    panic!("output dir is required");
                }
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::TruncatePatch(truncate_args) => {
            if let Err(error) =
                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }

        // Remaining commands are handled by the headless app below.
        _ => {}
    }

    // Phase 2: boot a headless GPUI app with an HTTP client attached.
    let http_client = Arc::new(ReqwestClient::new());
    let app = gpui_platform::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Synchronize provider batch state for the active command
                // before any example processing starts.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // With exactly one example, treat any failure as fatal (see
                // `handle_error`), so errors surface immediately.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                // Workers send serialized lines over a channel to a single
                // background writer task. With --in-place, writing targets a
                // `.jsonl.tmp` sibling that is renamed over the original at
                // the end of the run; otherwise the file is opened in append
                // mode.
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Examples are grouped per repository so each worker processes
                // one repo's examples at a time.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                // Spawn up to --max-parallelism workers; each repeatedly
                // claims the next repo group until the queue is drained.
                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Run the per-example step for the active command.
                                let result = async {
                                    match &command {
                                        Command::Read(_) => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::TruncatePatch(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::PrintZetaFormats => {
                                            // All handled (with early return)
                                            // by the phase-1 dispatch above.
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                // Record failures (and possibly abort the
                                // process) without stopping the worker loop.
                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Successful examples are always written;
                                // failed ones only with `--failed keep`.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o and not an eval: stream JSONL
                                        // to stdout.
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Tear down this repo's project before moving on:
                            // stop its language servers and detach it from the
                            // prediction store.
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            // Drop per-example runtime state before archiving
                            // the examples as finished.
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Post-run batch handling: with --wait, block until provider
                // batches settle, reprocess the finished examples, and rewrite
                // the output in full.
                let is_markdown = args.markdown;
                let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                        if args.wait {
                            predict::wait_for_batches(args.provider.as_ref()).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            predict::reprocess_after_batch_wait(&mut examples, args).await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                        if args.predict.wait {
                            predict::wait_for_batches(args.predict.provider.as_ref()).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            predict::reprocess_after_batch_wait(&mut examples, &args.predict)
                                .await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                        if args.wait {
                            repair::wait_for_batches(args).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            repair::reprocess_after_batch_wait(&mut examples, args).await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    _ => (),
                }

                // Final summary reports for the commands that produce one.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples, args.verbose);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    Command::Repair(args) => {
                        let examples = finished_examples.lock().unwrap();
                        repair::print_report(&examples, args.confidence_threshold);
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            // Any error propagated out of the run is fatal.
            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app so `app.run` returns.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1424
1425fn rewrite_output(
1426 examples: &[Example],
1427 output_path: Option<&PathBuf>,
1428 markdown: bool,
1429) -> anyhow::Result<()> {
1430 if markdown {
1431 let dir = output_path.context("--markdown requires -o")?;
1432 for example in examples {
1433 let filename = format!("{}.md", example.spec.filename());
1434 let path = dir.join(&filename);
1435 let markdown = example.spec.to_markdown();
1436 std::fs::write(&path, &markdown).context("Failed to write markdown file")?;
1437 }
1438 } else if let Some(path) = output_path {
1439 let file = OpenOptions::new()
1440 .create(true)
1441 .write(true)
1442 .truncate(true)
1443 .open(path)
1444 .context("Failed to open output file for rewriting")?;
1445 let mut writer = BufWriter::new(file);
1446 for example in examples {
1447 let line = serde_json::to_string(example)?;
1448 writeln!(writer, "{}", line)?;
1449 }
1450 writer.flush()?;
1451 } else {
1452 for example in examples {
1453 let line = serde_json::to_string(example)?;
1454 println!("{}", line);
1455 }
1456 }
1457 Ok(())
1458}
1459
1460async fn handle_error(
1461 error: anyhow::Error,
1462 args: &EpArgs,
1463 command: &Command,
1464 app_state: &Arc<headless::EpAppState>,
1465 failfast_on_single_example: bool,
1466 example: &Example,
1467) {
1468 Progress::global().increment_failed();
1469
1470 let msg;
1471 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1472 let example_name = example.spec.filename();
1473
1474 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1475 app_state
1476 .fs
1477 .write(
1478 &failed_example_path,
1479 &serde_json::to_vec_pretty(&example).unwrap(),
1480 )
1481 .await
1482 .unwrap();
1483 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1484 app_state
1485 .fs
1486 .write(&err_path, format!("{error:?}").as_bytes())
1487 .await
1488 .unwrap();
1489
1490 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1491 let mut file = OpenOptions::new()
1492 .create(true)
1493 .append(true)
1494 .open(&failed_jsonl_path)
1495 .expect("Failed to open failed.jsonl");
1496 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1497 .expect("Failed to write to failed.jsonl");
1498
1499 let cursor_path = match example.repo_name() {
1500 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1501 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1502 };
1503 msg = format!(
1504 indoc::indoc! {"
1505 While processing \"{}\":
1506
1507 \x1b[31m{:?}\x1b[0m
1508
1509 Example: \x1b[36m{}\x1b[0m
1510 Error file: \x1b[36m{}\x1b[0m
1511 Cursor file: \x1b[36m{}\x1b[0m
1512 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1513 "},
1514 example.spec.name,
1515 error,
1516 failed_example_path.display(),
1517 err_path.display(),
1518 cursor_path.display(),
1519 command,
1520 failed_example_path.display(),
1521 );
1522 } else {
1523 msg = format!(
1524 indoc::indoc! {"
1525 While processing \"{}\":
1526
1527 \x1b[31m{:?}\x1b[0m
1528 "},
1529 example.spec.name, error
1530 );
1531 }
1532
1533 if args.failfast || failfast_on_single_example {
1534 Progress::global().finalize();
1535 panic!("{}", msg);
1536 } else {
1537 log::error!("{}", msg);
1538 }
1539}