1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8
9mod load_project;
10mod metrics;
11mod openai_client;
12mod parse_output;
13mod paths;
14mod predict;
15mod progress;
16mod prompt_assets;
17mod pull_examples;
18mod qa;
19mod reorder_patch;
20mod repair;
21mod retrieve_context;
22mod reversal_tracking;
23mod score;
24mod split_commit;
25mod split_dataset;
26
27mod synthesize;
28mod truncate_expected_patch;
29mod word_diff;
30use anyhow::Context as _;
31use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
32use collections::{HashMap, HashSet};
33use edit_prediction::EditPredictionStore;
34use futures::channel::mpsc;
35use futures::{SinkExt as _, StreamExt as _};
36use gaoya::minhash::{
37 MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
38};
39use gpui::{AppContext as _, BackgroundExecutor, Task};
40use zeta_prompt::ZetaFormat;
41
42use reqwest_client::ReqwestClient;
43use serde::{Deserialize, Deserializer, Serialize, Serializer};
44use std::env;
45use std::fmt::Display;
46use std::fs::{File, OpenOptions};
47use std::hash::{Hash, Hasher};
48use std::io::{BufRead, BufReader, BufWriter, Write};
49use std::str::FromStr;
50use std::sync::Mutex;
51use std::{path::PathBuf, sync::Arc};
52
53use crate::distill::run_distill;
54use crate::example::{Example, group_examples_by_repo, read_example_files};
55use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
56use crate::format_prompt::run_format_prompt;
57use crate::load_project::run_load_project;
58use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
59use crate::predict::run_prediction;
60use crate::progress::Progress;
61use crate::pull_examples::{fetch_settled_examples_after, parse_settled_after_input};
62use crate::retrieve_context::run_context_retrieval;
63use crate::score::run_scoring;
64use crate::split_commit::SplitCommitArgs;
65use crate::split_dataset::SplitArgs;
66use crate::synthesize::{SynthesizeConfig, run_synthesize};
67use crate::truncate_expected_patch::TruncatePatchArgs;
68
// Top-level CLI arguments. Most flags are `global = true` so they can appear
// before or after the subcommand. Non-doc comments are used for fields below
// so clap's generated --help output is unchanged.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment (util::shell_env::print_env) and exit.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Upper bound on concurrently processed examples. NOTE(review): usage is
    // outside this chunk — confirm it gates the worker pool.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to skip; applied to file inputs first, any
    // remainder is forwarded to the Snowflake fetches (see load_examples).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    /// Deduplicate by cursor position and keep at most this many examples per cluster
    #[clap(long, global = true)]
    max_duplicates: Option<usize>,
    #[command(subcommand)]
    command: Option<Command>,
    /// Input file paths
    #[clap(global = true)]
    inputs: Vec<PathBuf>,
    // Output path (a directory when --markdown is set). Resolved together with
    // --in-place by EpArgs::output_path.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place; requires exactly one input
    // (enforced with a panic in output_path).
    #[arg(long, short, global = true)]
    in_place: bool,
    // NOTE(review): presumably aborts on the first failed example — the flag is
    // consumed outside this chunk; confirm.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
111
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    // NOTE(review): the doc above says failures are "always logged" — confirm
    // whether this variant suppresses the failed/ directory writes as well.
    /// Skip writing files
    SkipNoFiles,
}
124
/// Extended help text appended to `--help` (via `after_help` on `ReadArgs`).
/// Documents the special input specifiers and the Snowflake environment
/// variables. Keep in sync with the `parse_*_after_input` functions in
/// `pull_examples` — the specifier prefixes here must match what they accept.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

    path
        Path to an example(s) file (.md, .json, or .jsonl)

    captured-after:{timestamp}
        Fetch captured examples from Snowflake after the given RFC3339 timestamp.
        These are examples captured via the "Capture Edit Prediction Example" action.

    rejected-after:{timestamp}
        Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
        These are predictions that were shown to users but rejected (useful for DPO training).

    settled-after:{timestamp}
        Fetch settled stream examples from Snowflake after the given RFC3339 timestamp.
        These are examples from the edit prediction settled stream.

    rated-after:{timestamp}
        Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
        These are predictions that users explicitly rated as positive or negative via the
        rate completions modal. Only zeta2 predictions are included.
        - Positive ratings: output becomes expected_patches
        - Negative ratings: output becomes rejected_patch

    rated-positive-after:{timestamp}
        Same as rated-after, but only fetches positively rated predictions.

    rated-negative-after:{timestamp}
        Same as rated-after, but only fetches negatively rated predictions.

    Required environment variables to connect to Snowflake:
        EP_SNOWFLAKE_API_KEY
        EP_SNOWFLAKE_BASE_URL

    Optional:
        EP_SNOWFLAKE_ROLE

Examples:

    # Read examples from a file
    ep read examples.jsonl -o output.jsonl

    # Read captured examples after a timestamp
    ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

    # Read rejected predictions for DPO training
    ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

    # Read user-rated predictions
    ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

    # Read settled stream examples
    ep read settled-after:2025-01-01T00:00:00Z -o settled.jsonl

    # Read only positively rated predictions
    ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

    # Read only negatively rated predictions
    ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

    # Mix multiple input sources
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
189
// All ep subcommands. Doc comments double as clap help text, so new prose for
// the enum itself is a regular comment. Keep the Display impl below in sync
// when adding variants — it produces the kebab-case names used in logs.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read(ReadArgs),
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
233
234impl Display for Command {
235 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 match self {
237 Command::Read(_) => write!(f, "read"),
238 Command::LoadProject => write!(f, "load-project"),
239 Command::Context => write!(f, "context"),
240 Command::FormatPrompt(args) => {
241 write!(f, "format-prompt --provider={}", args.provider)
242 }
243 Command::Predict(args) => match &args.provider {
244 Some(provider) => write!(f, "predict --provider={}", provider),
245 None => write!(f, "predict"),
246 },
247 Command::ParseOutput => write!(f, "parse-output"),
248 Command::Score(args) => match &args.provider {
249 Some(provider) => write!(f, "score --provider={}", provider),
250 None => write!(f, "score"),
251 },
252 Command::Distill => write!(f, "distill"),
253 Command::Eval(args) => match &args.predict.provider {
254 Some(provider) => write!(f, "eval --provider={}", provider),
255 None => write!(f, "eval"),
256 },
257 Command::Synthesize(args) => {
258 write!(f, "synthesize --repos {}", args.repos.join(" "))
259 }
260 Command::Clean => write!(f, "clean"),
261 Command::SplitCommit(_) => write!(f, "split-commit"),
262 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
263 Command::Split(_) => write!(f, "split"),
264 Command::FilterLanguages(_) => write!(f, "filter-languages"),
265 Command::ImportBatch(args) => {
266 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
267 }
268 Command::Qa(_) => {
269 write!(f, "qa")
270 }
271 Command::Repair(_) => {
272 write!(f, "repair")
273 }
274 Command::PrintZetaFormats => {
275 write!(f, "print-zeta-formats")
276 }
277 }
278 }
279}
280
// `read` takes no flags of its own (only the global options); this empty Args
// struct exists to attach the INPUTS_HELP after-help text to the subcommand.
#[derive(Debug, Args, Clone)]
#[command(after_help = INPUTS_HELP)]
struct ReadArgs {}
284
// Arguments for the `format-prompt` subcommand. Unlike PredictArgs, the
// provider is required here (with a default) since the prompt format depends on it.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Target prediction provider; parsed through PredictionProvider::from_str.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
290
// Arguments shared by `predict`, `score`, and (via EvalArgs) `eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider override. NOTE(review): the None behavior is decided by
    // run_prediction/run_scoring outside this chunk — presumably a default or
    // "all providers"; confirm there.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // How many predictions to request per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
    /// Wait for all batches to complete before exiting (only applies to batched providers like teacher)
    #[clap(long)]
    wait: bool,
}
304
// Arguments for `eval`: prediction options plus reporting controls.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // All prediction flags are accepted by `eval` too (flattened into this struct).
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
    /// Print all individual example lines (default: up to 20)
    #[clap(long)]
    verbose: bool,
}
316
/// The upstream LLM backing the "teacher" prediction providers.
/// Maps to a concrete model id via [`TeacherBackend::model_name`]; parsed from
/// CLI strings (with aliases) via `FromStr` below.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    Sonnet46,
    /// Default backend.
    #[default]
    Sonnet45,
    Gpt52,
}
324
325impl std::fmt::Display for TeacherBackend {
326 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327 match self {
328 TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
329 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
330 TeacherBackend::Gpt52 => write!(f, "gpt52"),
331 }
332 }
333}
334
335impl std::str::FromStr for TeacherBackend {
336 type Err = anyhow::Error;
337
338 fn from_str(s: &str) -> Result<Self, Self::Err> {
339 match s.to_lowercase().as_str() {
340 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
341 "sonnet46" => Ok(TeacherBackend::Sonnet46),
342 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
343 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
344 _ => anyhow::bail!(
345 "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
346 ),
347 }
348 }
349}
350
351impl TeacherBackend {
352 pub fn model_name(&self) -> &'static str {
353 match self {
354 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
355 TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
356 TeacherBackend::Gpt52 => "gpt-5.2",
357 }
358 }
359}
360
/// Which backend produces predictions for an example. Serialized to/from JSONL
/// as a string via the `Display`/`FromStr` impls below (e.g. `zeta2:<format>`,
/// `teacher:<backend>:<format>`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Mercury,
    Zeta1,
    Zeta2(ZetaFormat),
    Baseten(ZetaFormat),
    Teacher(TeacherBackend, ZetaFormat),
    TeacherMultiRegion(TeacherBackend),
    TeacherNonBatching(TeacherBackend, ZetaFormat),
    TeacherMultiRegionNonBatching(TeacherBackend),
    Repair,
}
373
374impl Default for PredictionProvider {
375 fn default() -> Self {
376 PredictionProvider::Zeta2(ZetaFormat::default())
377 }
378}
379
impl std::fmt::Display for PredictionProvider {
    // Produces the string form that `FromStr` below parses.
    // NOTE(review): the teacher variants interpolate the format with `{format:?}`
    // (Debug) while Zeta2/Baseten use `{format}` (Display). If ZetaFormat's
    // Debug output differs from what `ZetaFormat::parse` accepts, teacher
    // provider strings will not round-trip through FromStr — confirm whether
    // the Debug form here is intentional.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PredictionProvider::Mercury => write!(f, "mercury"),
            PredictionProvider::Zeta1 => write!(f, "zeta1"),
            PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
            PredictionProvider::Baseten(format) => write!(f, "baseten:{format}"),
            PredictionProvider::Teacher(backend, format) => {
                write!(f, "teacher:{backend}:{format:?}")
            }
            PredictionProvider::TeacherMultiRegion(backend) => {
                write!(f, "teacher-multi-region:{backend}")
            }
            PredictionProvider::TeacherNonBatching(backend, format) => {
                write!(f, "teacher-non-batching:{backend}:{format:?}")
            }
            PredictionProvider::TeacherMultiRegionNonBatching(backend) => {
                write!(f, "teacher-multi-region-non-batching:{backend}")
            }
            PredictionProvider::Repair => write!(f, "repair"),
        }
    }
}
403
404impl std::str::FromStr for PredictionProvider {
405 type Err = anyhow::Error;
406
407 fn from_str(s: &str) -> Result<Self, Self::Err> {
408 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
409
410 let provider_lower = provider.to_lowercase();
411 match provider_lower.as_str() {
412 "mercury" => Ok(PredictionProvider::Mercury),
413 "zeta1" => Ok(PredictionProvider::Zeta1),
414 "zeta2" => {
415 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
416 Ok(PredictionProvider::Zeta2(format))
417 }
418 "teacher" => {
419 let (backend, format) = parse_teacher_args(arg)?;
420 Ok(PredictionProvider::Teacher(backend, format))
421 }
422 "teacher-non-batching" | "teacher_non_batching" => {
423 let (backend, format) = parse_teacher_args(arg)?;
424 Ok(PredictionProvider::TeacherNonBatching(backend, format))
425 }
426 "teacher-multi-region" | "teacher_multi_region" => {
427 let backend = arg
428 .map(|a| a.parse())
429 .transpose()?
430 .unwrap_or(TeacherBackend::default());
431 Ok(PredictionProvider::TeacherMultiRegion(backend))
432 }
433 "teacher-multi-region-non-batching" | "teacher_multi_region_non_batching" => {
434 let backend = arg
435 .map(|a| a.parse())
436 .transpose()?
437 .unwrap_or(TeacherBackend::default());
438 Ok(PredictionProvider::TeacherMultiRegionNonBatching(backend))
439 }
440 "repair" => Ok(PredictionProvider::Repair),
441 "baseten" => {
442 let format = arg
443 .map(ZetaFormat::parse)
444 .transpose()?
445 .unwrap_or(ZetaFormat::default());
446 Ok(PredictionProvider::Baseten(format))
447 }
448 _ => {
449 anyhow::bail!(
450 "unknown provider `{provider}`. Valid options: mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-multi-region, teacher-multi-region:<backend>, teacher-non-batching, teacher-multi-region-non-batching, repair\n\
451 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
452 For teacher providers, you can specify a backend like `teacher:sonnet46`, `teacher-multi-region:sonnet46`, `teacher-multi-region-non-batching:sonnet46`, or `teacher:gpt52`.\n\
453 Available zeta versions:\n{}",
454 ZetaFormat::options_as_string()
455 )
456 }
457 }
458 }
459}
460
461fn parse_teacher_args(arg: Option<&str>) -> Result<(TeacherBackend, ZetaFormat), anyhow::Error> {
462 let mut backend = TeacherBackend::default();
463 let mut format = ZetaFormat::default();
464
465 for arg in arg.unwrap_or_default().split(':') {
466 if arg.is_empty() {
467 continue;
468 }
469
470 if let Ok(parsed_backend) = TeacherBackend::from_str(arg) {
471 backend = parsed_backend;
472 } else if let Ok(parsed_format) = ZetaFormat::parse(arg) {
473 format = parsed_format;
474 } else {
475 anyhow::bail!("unknown teacher backend or zeta format `{arg}`");
476 }
477 }
478
479 Ok((backend, format))
480}
481
482impl Serialize for PredictionProvider {
483 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
484 where
485 S: Serializer,
486 {
487 serializer.serialize_str(&self.to_string())
488 }
489}
490
491impl<'de> Deserialize<'de> for PredictionProvider {
492 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
493 where
494 D: Deserializer<'de>,
495 {
496 let s = String::deserialize(deserializer)?;
497 s.parse().map_err(serde::de::Error::custom)
498 }
499}
500
// Arguments for the `synthesize` subcommand (commit-mining example generation);
// converted into a SynthesizeConfig in main().
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
519
// Arguments for `import-batch`: re-ingest completed provider batches into the
// local LLM cache database (see the ImportBatch arm of main()).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
529
// Provider whose batch API results are imported; selects the Anthropic or
// OpenAI client in the ImportBatch arm of main().
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
535
#[cfg(test)]
mod tests {
    use super::*;

    // Canonical hyphenated spelling parses to the expected variant and
    // Display reproduces the exact same string (FromStr/Display round-trip).
    #[test]
    fn prediction_provider_multi_region_non_batched_round_trips_to_primary_spelling() {
        let provider: PredictionProvider = "teacher-multi-region-non-batching:sonnet46"
            .parse()
            .unwrap();
        assert_eq!(
            provider,
            PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Sonnet46)
        );
        assert_eq!(
            provider.to_string(),
            "teacher-multi-region-non-batching:sonnet46"
        );
    }

    // The underscore alias parses to the same variant, but Display normalizes
    // to the hyphenated primary spelling.
    #[test]
    fn prediction_provider_multi_region_non_batched_alias_round_trips_to_primary_spelling() {
        let provider: PredictionProvider =
            "teacher_multi_region_non_batching:gpt52".parse().unwrap();
        assert_eq!(
            provider,
            PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Gpt52)
        );
        assert_eq!(
            provider.to_string(),
            "teacher-multi-region-non-batching:gpt52"
        );
    }
}
569
570impl EpArgs {
571 fn output_path(&self) -> Option<PathBuf> {
572 if self.in_place {
573 if self.inputs.len() == 1 {
574 self.inputs.first().cloned()
575 } else {
576 panic!("--in-place requires exactly one input file")
577 }
578 } else {
579 self.output.clone()
580 }
581 }
582}
583
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    // NOTE(review): presumably minor/patch of a 0.x Zed release (i.e. 0.224.1) —
    // confirm against pull_examples::MinCaptureVersion's comparison logic.
    minor: 224,
    patch: 1,
};
591
/// Removes duplicate examples in place, keeping at most `max_per_cluster`
/// examples per cluster of near-duplicates.
///
/// Two passes:
/// 1. Exact: drop examples whose `cursor_position` string was already seen
///    (first occurrence wins).
/// 2. Near: MinHash signatures over token n-grams of `cursor_position`, LSH
///    candidate pairs merged into clusters via union-find, then a diverse
///    subset of each cluster is kept via `greedy_max_min_diverse`.
///
/// The surviving examples keep their original relative order.
fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
    let total_before_exact = examples.len();
    let mut seen_positions = HashSet::default();
    // retain() visits in order, so the first occurrence of each position survives.
    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
    log::info!(
        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
        examples.len(),
        total_before_exact - examples.len(),
    );

    // LSH tuning: pairs whose estimated Jaccard similarity exceeds the
    // threshold become cluster-merge candidates.
    const JACCARD_THRESHOLD: f64 = 0.5;
    const NUM_HASHES: usize = 128;
    const TOKEN_NGRAM_SIZE: usize = 5;

    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
    // bands * band_width may round away from NUM_HASHES; use the exact product.
    let num_hashes = num_bands * band_width;
    let minhasher = MinHasher32::new(num_hashes);
    let mut index: MinHashIndex<u32, usize> =
        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);

    // One signature per example, indexed by position in `examples`.
    let signatures: Vec<Vec<u32>> = examples
        .iter()
        .map(|example| {
            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
            minhasher.create_signature(shingles.iter())
        })
        .collect();

    for (id, signature) in signatures.iter().enumerate() {
        index.insert(id, signature.clone());
    }

    // Build clusters via union-find on LSH candidate pairs.
    let mut parent: Vec<usize> = (0..examples.len()).collect();

    // Find with path halving; no union-by-rank, which is fine at this scale.
    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
        while parent[x] != x {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        x
    }

    for (id, signature) in signatures.iter().enumerate() {
        for candidate in index.query_owned(signature) {
            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
            if a != b {
                parent[a] = b;
            }
        }
    }

    // Group example indices by their union-find root.
    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
    for id in 0..examples.len() {
        clusters.entry(find(&mut parent, id)).or_default().push(id);
    }

    let mut keep: HashSet<usize> = HashSet::default();
    for members in clusters.values() {
        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
        keep.extend(selected);
    }

    let total = examples.len();
    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
    kept_indices.sort();

    // Remove from highest index downward: swap_remove only disturbs positions
    // at/above the removal point, which are never needed again, so the lower
    // kept indices stay valid. Reversing afterwards restores original order.
    let mut retained = Vec::with_capacity(kept_indices.len());
    for index in kept_indices.into_iter().rev() {
        retained.push(examples.swap_remove(index));
    }
    retained.reverse();

    *examples = retained;
    log::info!(
        "near-duplicate filter: {total} examples → {} examples ({} removed)",
        examples.len(),
        total - examples.len(),
    );
}
672
673fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
674 if members.len() <= k {
675 return members.to_vec();
676 }
677
678 let mut selected = vec![members[0]];
679 let mut min_dist: HashMap<usize, f64> = HashMap::default();
680 for &member in &members[1..] {
681 let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
682 min_dist.insert(member, dist);
683 }
684
685 while selected.len() < k {
686 let &best = min_dist
687 .iter()
688 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
689 .map(|(id, _)| id)
690 .expect("min_dist should not be empty when selected.len() < k");
691 selected.push(best);
692 min_dist.remove(&best);
693
694 let best_sig = &signatures[best];
695 for (member, current_min) in min_dist.iter_mut() {
696 let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
697 if dist < *current_min {
698 *current_min = dist;
699 }
700 }
701 }
702
703 selected
704}
705
706fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
707 let tokens: Vec<&str> = word_diff::tokenize(code)
708 .into_iter()
709 .filter(|t| !t.trim().is_empty())
710 .collect();
711
712 if tokens.len() < ngram_size {
713 return vec![tokens.join("\0")];
714 }
715
716 tokens
717 .windows(ngram_size)
718 .map(|window| window.join("\0"))
719 .collect()
720}
721
/// Gathers the example set for a run from all requested sources, in order:
/// local files, then the various Snowflake streams (rejected, requested,
/// captured, settled, rated). Afterwards applies, in sequence: sorting by
/// repo/rev, the --name/--repo filters, resume-skipping of already-processed
/// examples, --max-duplicates deduplication, and the --limit truncation.
///
/// `--offset` is consumed by the file inputs first; any remainder is forwarded
/// to each Snowflake fetch.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Bucket each CLI input by which specifier (if any) it matches.
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut settled_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) = parse_settled_after_input(input_string.as_ref()) {
            settled_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            // Anything that is not a recognized specifier is treated as a path.
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            // Offset swallows all file examples; forward the rest downstream.
            examples.clear();
            offset - file_example_count
        } else {
            // Drop the first `offset` file examples; nothing left to forward.
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    // Provisional total for progress display; updated again after filtering below.
    Progress::global().set_total_examples(examples.len());

    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // Each timestamp group gets the same row cap; the final --limit
        // truncation below enforces the overall bound.
        let max_rows_per_timestamp = remaining_limit_for_snowflake;

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !settled_after_timestamps.is_empty() {
            settled_after_timestamps.sort();

            let mut settled_examples = fetch_settled_examples_after(
                http_client.clone(),
                &settled_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut settled_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            // Last fetch: http_client and background_executor are moved, not cloned.
            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from --name / --repo.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(max_duplicates) = args.max_duplicates {
        deduplicate_examples(&mut examples, max_duplicates);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Final totals now that all filters have run.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
896
897fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
898 let mut hasher = collections::FxHasher::default();
899 spec.hash(&mut hasher);
900 hasher.finish()
901}
902
/// Resume support: reconciles a previous run's output file with the current
/// input set so already-processed examples are not redone.
///
/// Keeps an output line only if (a) its spec hash matches one of the current
/// input examples, (b) it is the first line seen for that hash, and (c) it
/// looks "complete" for the given command. The output file is then rewritten
/// with just those lines, and the matching examples are removed from
/// `examples` so the caller skips them.
///
/// Missing output file means nothing to resume; unreadable/unparseable lines
/// are silently dropped (they will be reprocessed).
fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
    let file = match File::open(path) {
        Ok(f) => f,
        // No prior output — nothing to resume from.
        Err(_) => return,
    };

    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();

    let reader = BufReader::new(file);
    let mut kept_lines = Vec::new();
    let mut kept_hashes = HashSet::default();

    for line in reader.lines() {
        let line = match line {
            Ok(l) => l,
            // I/O error mid-file: skip the line rather than abort the resume.
            Err(_) => continue,
        };

        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
            let hash = spec_hash(&output_example.spec);
            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
                // Command-specific completeness check: a line that exists but
                // lacks the command's result must be reprocessed.
                let is_complete = match command {
                    Command::Qa(_) => output_example
                        .qa
                        .first()
                        .and_then(|q| q.as_ref())
                        .and_then(|q| q.confidence)
                        .is_some(),
                    Command::Repair(_) => output_example.predictions.iter().any(|p| {
                        p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
                    }),
                    // Other commands: mere presence in the output counts as done.
                    _ => true,
                };
                if is_complete {
                    kept_hashes.insert(hash);
                    kept_lines.push(line);
                }
            }
        }
    }

    let total = examples.len();
    let already_processed = kept_hashes.len();

    eprintln!(
        "Resuming: {}/{} examples already processed",
        already_processed, total
    );

    // Rewrite the output with only the kept lines, dropping stale/duplicate
    // entries so subsequent appends produce a consistent file.
    let file = OpenOptions::new()
        .write(true)
        .truncate(true)
        .open(path)
        .expect("Failed to open output file for rewriting");
    let mut writer = BufWriter::new(file);
    for line in &kept_lines {
        writeln!(writer, "{}", line).expect("Failed to write to output file");
    }
    writer.flush().expect("Failed to flush output file");

    // Remove resumed examples from the in-memory set so they are skipped.
    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
}
965
/// CLI entry point for the edit-prediction tooling.
///
/// Flow: parse arguments, handle commands that need no app state
/// synchronously (batch import, cleanup, dataset utilities) and return
/// early; every remaining command runs inside a headless GPUI application
/// so it can load projects and drive async example processing.
fn main() {
    let args = EpArgs::parse();

    // `--printenv` is a debugging escape hatch: dump the shell env and exit.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    // Markdown mode writes one file per example, so an output directory is mandatory.
    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // With no subcommand, print help rather than erroring.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // First dispatch: commands that never touch the headless app or the
    // example-processing pipeline. Every arm returns, so only the remaining
    // commands fall through to the GPUI-based pipeline below.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Wipes the entire local data directory.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::PrintZetaFormats => {
            use strum::IntoEnumIterator as _;
            for format in ZetaFormat::iter() {
                println!("{}", format.to_string().to_lowercase());
            }
            return;
        }

        Command::Synthesize(synth_args) => {
            // Default the output dir to a path inside the repo checkout when
            // -o is not given; create it if its parent already exists.
            let output_dir = if let Some(output_dir) = args.output {
                output_dir
            } else {
                let default_output_dir = env::current_dir()
                    .unwrap()
                    .join("crates/edit_prediction_cli/evals-generated");
                if default_output_dir.parent().unwrap().exists() {
                    std::fs::create_dir(&default_output_dir).ok();
                    default_output_dir
                } else {
                    panic!("output dir is required");
                }
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::TruncatePatch(truncate_args) => {
            if let Err(error) =
                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }

        // Remaining commands are handled by the async pipeline below.
        _ => {}
    }

    // Everything from here on needs the headless GPUI application: an HTTP
    // client, the app state from `headless::init`, and the global
    // EditPredictionStore.
    let http_client = Arc::new(ReqwestClient::new());
    let app = gpui_platform::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Pull down any already-completed provider batches up front so
                // cached results are available while processing.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // When running a single example, failing loudly is more useful
                // than continuing past an error.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode).
                // Lines are funneled through a channel to a single background
                // task so parallel workers never interleave partial writes.
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // --in-place writes to a temp file that is atomically
                    // renamed over the original at the end of the run.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            // Flush per line so partial progress survives a crash.
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Examples are grouped by repository; each worker takes one
                // repo group at a time so a repo's examples run sequentially.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                // Spawn up to max_parallelism work-stealing workers that pop
                // repo groups from the shared queue until it is empty.
                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Per-example dispatch for the pipeline commands.
                                let result = async {
                                    match &command {
                                        Command::Read(_) => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        // These commands were handled (and returned)
                                        // in the first dispatch above.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::TruncatePatch(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::PrintZetaFormats => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                // Record the failure (and possibly panic, per
                                // --failfast) but keep processing other examples.
                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are only written when --failed=keep.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No output file: emit JSONL on stdout
                                        // (Eval prints a report instead).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // After finishing a repo group, tear down its loaded
                            // project: shut down language servers and drop the
                            // project from the prediction store.
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            // Drop per-example state before retaining results.
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // If --wait was given, block until provider batches settle,
                // reprocess the examples with the batch results, and rewrite
                // the output (temp file when --in-place).
                let is_markdown = args.markdown;
                let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                        if args.wait {
                            predict::wait_for_batches(args.provider.as_ref()).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            predict::reprocess_after_batch_wait(&mut examples, args).await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                        if args.predict.wait {
                            predict::wait_for_batches(args.predict.provider.as_ref()).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            predict::reprocess_after_batch_wait(&mut examples, &args.predict)
                                .await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                        if args.wait {
                            repair::wait_for_batches(args).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            repair::reprocess_after_batch_wait(&mut examples, args).await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    _ => (),
                }

                // Final reports for the commands that summarize results.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples, args.verbose);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    Command::Repair(args) => {
                        let examples = finished_examples.lock().unwrap();
                        repair::print_report(&examples, args.confidence_threshold);
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            // Surface any pipeline error as a process failure.
            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1440
1441fn rewrite_output(
1442 examples: &[Example],
1443 output_path: Option<&PathBuf>,
1444 markdown: bool,
1445) -> anyhow::Result<()> {
1446 if markdown {
1447 let dir = output_path.context("--markdown requires -o")?;
1448 for example in examples {
1449 let filename = format!("{}.md", example.spec.filename());
1450 let path = dir.join(&filename);
1451 let markdown = example.spec.to_markdown();
1452 std::fs::write(&path, &markdown).context("Failed to write markdown file")?;
1453 }
1454 } else if let Some(path) = output_path {
1455 let file = OpenOptions::new()
1456 .create(true)
1457 .write(true)
1458 .truncate(true)
1459 .open(path)
1460 .context("Failed to open output file for rewriting")?;
1461 let mut writer = BufWriter::new(file);
1462 for example in examples {
1463 let line = serde_json::to_string(example)?;
1464 writeln!(writer, "{}", line)?;
1465 }
1466 writer.flush()?;
1467 } else {
1468 for example in examples {
1469 let line = serde_json::to_string(example)?;
1470 println!("{}", line);
1471 }
1472 }
1473 Ok(())
1474}
1475
1476async fn handle_error(
1477 error: anyhow::Error,
1478 args: &EpArgs,
1479 command: &Command,
1480 app_state: &Arc<headless::EpAppState>,
1481 failfast_on_single_example: bool,
1482 example: &Example,
1483) {
1484 Progress::global().increment_failed();
1485
1486 let msg;
1487 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1488 let example_name = example.spec.filename();
1489
1490 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1491 app_state
1492 .fs
1493 .write(
1494 &failed_example_path,
1495 &serde_json::to_vec_pretty(&example).unwrap(),
1496 )
1497 .await
1498 .unwrap();
1499 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1500 app_state
1501 .fs
1502 .write(&err_path, format!("{error:?}").as_bytes())
1503 .await
1504 .unwrap();
1505
1506 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1507 let mut file = OpenOptions::new()
1508 .create(true)
1509 .append(true)
1510 .open(&failed_jsonl_path)
1511 .expect("Failed to open failed.jsonl");
1512 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1513 .expect("Failed to write to failed.jsonl");
1514
1515 let cursor_path = match example.repo_name() {
1516 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1517 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1518 };
1519 msg = format!(
1520 indoc::indoc! {"
1521 While processing \"{}\":
1522
1523 \x1b[31m{:?}\x1b[0m
1524
1525 Example: \x1b[36m{}\x1b[0m
1526 Error file: \x1b[36m{}\x1b[0m
1527 Cursor file: \x1b[36m{}\x1b[0m
1528 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1529 "},
1530 example.spec.name,
1531 error,
1532 failed_example_path.display(),
1533 err_path.display(),
1534 cursor_path.display(),
1535 command,
1536 failed_example_path.display(),
1537 );
1538 } else {
1539 msg = format!(
1540 indoc::indoc! {"
1541 While processing \"{}\":
1542
1543 \x1b[31m{:?}\x1b[0m
1544 "},
1545 example.spec.name, error
1546 );
1547 }
1548
1549 if args.failfast || failfast_on_single_example {
1550 Progress::global().finalize();
1551 panic!("{}", msg);
1552 } else {
1553 log::error!("{}", msg);
1554 }
1555}