1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8
9mod load_project;
10mod metrics;
11mod openai_client;
12mod parse_output;
13mod paths;
14mod predict;
15mod progress;
16mod prompt_assets;
17mod pull_examples;
18mod qa;
19mod reorder_patch;
20mod repair;
21mod retrieve_context;
22mod reversal_tracking;
23mod score;
24mod split_commit;
25mod split_dataset;
26
27mod synthesize;
28mod truncate_expected_patch;
29mod word_diff;
use anyhow::Context as _;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::{HashMap, HashSet};
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gaoya::minhash::{
    MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
};
use gpui::{AppContext as _, BackgroundExecutor, Task};
use zeta_prompt::ZetaFormat;

use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::env;
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::Path;
use std::str::FromStr;
use std::sync::Mutex;
use std::{path::PathBuf, sync::Arc};
52
53use crate::distill::run_distill;
54use crate::example::{Example, group_examples_by_repo, read_example_files};
55use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
56use crate::format_prompt::run_format_prompt;
57use crate::load_project::run_load_project;
58use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
59use crate::predict::run_prediction;
60use crate::progress::Progress;
61use crate::pull_examples::{fetch_settled_examples_after, parse_settled_after_input};
62use crate::retrieve_context::run_context_retrieval;
63use crate::score::run_scoring;
64use crate::split_commit::SplitCommitArgs;
65use crate::split_dataset::SplitArgs;
66use crate::synthesize::{SynthesizeConfig, run_synthesize};
67use crate::truncate_expected_patch::TruncatePatchArgs;
68
// Top-level CLI arguments for the `ep` binary. Most flags are marked
// `global = true` so they can appear before or after the subcommand.
// NOTE: plain `//` comments are used here deliberately — `///` doc comments
// on clap-derive fields become user-visible help text.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the resolved shell environment and exit (handled early in `main`).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Maximum concurrency — presumably caps parallel example processing;
    // confirm against the command runners.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of examples to skip; applied to file inputs first, remainder
    // forwarded to the Snowflake fetches (see `load_examples`).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    /// Deduplicate by cursor position and keep at most this many examples per cluster
    #[clap(long, global = true)]
    max_duplicates: Option<usize>,
    #[command(subcommand)]
    command: Option<Command>,
    /// Input file paths
    #[clap(global = true)]
    inputs: Vec<PathBuf>,
    // Output path; a directory when --markdown is set (see `markdown` below).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Write results back into the single input file (see `output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Presumably stops at the first failing example — confirm in the runners.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
111
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // NOTE(review): name suggests this also suppresses the per-example failure
    // files in the run's failed/ directory — confirm against the writer code.
    SkipNoFiles,
}
124
// After-help text attached to `ep read --help` via `ReadArgs`' `after_help`
// attribute. This is user-facing output: keep the wording and the listed
// specifier prefixes in sync with the parsers in `pull_examples`.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

    path
        Path to an example(s) file (.md, .json, or .jsonl)

    captured-after:{timestamp}
        Fetch captured examples from Snowflake after the given RFC3339 timestamp.
        These are examples captured via the "Capture Edit Prediction Example" action.

    rejected-after:{timestamp}
        Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
        These are predictions that were shown to users but rejected (useful for DPO training).

    settled-after:{timestamp}
        Fetch settled stream examples from Snowflake after the given RFC3339 timestamp.
        These are examples from the edit prediction settled stream.

    rated-after:{timestamp}
        Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
        These are predictions that users explicitly rated as positive or negative via the
        rate completions modal. Only zeta2 predictions are included.
        - Positive ratings: output becomes expected_patches
        - Negative ratings: output becomes rejected_patch

    rated-positive-after:{timestamp}
        Same as rated-after, but only fetches positively rated predictions.

    rated-negative-after:{timestamp}
        Same as rated-after, but only fetches negatively rated predictions.

    Required environment variables to connect to Snowflake:
        EP_SNOWFLAKE_API_KEY
        EP_SNOWFLAKE_BASE_URL

    Optional:
        EP_SNOWFLAKE_ROLE

Examples:

    # Read examples from a file
    ep read examples.jsonl -o output.jsonl

    # Read captured examples after a timestamp
    ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

    # Read rejected predictions for DPO training
    ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

    # Read user-rated predictions
    ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

    # Read settled stream examples
    ep read settled-after:2025-01-01T00:00:00Z -o settled.jsonl

    # Read only positively rated predictions
    ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

    # Read only negatively rated predictions
    ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

    # Mix multiple input sources
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
189
// Subcommands of `ep`. The `///` comments double as clap help text, so keep
// them user-oriented; implementation notes go in `//` comments.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read(ReadArgs),
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    // Shares PredictArgs with Predict so scoring can select the same provider.
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
233
234impl Display for Command {
235 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 match self {
237 Command::Read(_) => write!(f, "read"),
238 Command::LoadProject => write!(f, "load-project"),
239 Command::Context => write!(f, "context"),
240 Command::FormatPrompt(args) => {
241 write!(f, "format-prompt --provider={}", args.provider)
242 }
243 Command::Predict(args) => match &args.provider {
244 Some(provider) => write!(f, "predict --provider={}", provider),
245 None => write!(f, "predict"),
246 },
247 Command::ParseOutput => write!(f, "parse-output"),
248 Command::Score(args) => match &args.provider {
249 Some(provider) => write!(f, "score --provider={}", provider),
250 None => write!(f, "score"),
251 },
252 Command::Distill => write!(f, "distill"),
253 Command::Eval(args) => match &args.predict.provider {
254 Some(provider) => write!(f, "eval --provider={}", provider),
255 None => write!(f, "eval"),
256 },
257 Command::Synthesize(args) => {
258 write!(f, "synthesize --repos {}", args.repos.join(" "))
259 }
260 Command::Clean => write!(f, "clean"),
261 Command::SplitCommit(_) => write!(f, "split-commit"),
262 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
263 Command::Split(_) => write!(f, "split"),
264 Command::FilterLanguages(_) => write!(f, "filter-languages"),
265 Command::ImportBatch(args) => {
266 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
267 }
268 Command::Qa(_) => {
269 write!(f, "qa")
270 }
271 Command::Repair(_) => {
272 write!(f, "repair")
273 }
274 Command::PrintZetaFormats => {
275 write!(f, "print-zeta-formats")
276 }
277 }
278 }
279}
280
// `read` takes no flags of its own; the actual inputs come from the global
// positional `inputs` on `EpArgs`. This struct exists to attach the
// specifier documentation below to `ep read --help`.
#[derive(Debug, Args, Clone)]
#[command(after_help = INPUTS_HELP)]
struct ReadArgs {}
284
// Arguments for `format-prompt`. Unlike `PredictArgs`, the provider is
// required here (with a default) because the prompt format depends on it.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
290
// Arguments shared by `predict`, `score`, and (flattened) `eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional: `None` presumably means "use whatever provider the prompt
    // was formatted for" — confirm in `run_prediction`.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Number of predictions to request per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
    /// Wait for all batches to complete before exiting (only applies to batched providers like teacher)
    #[clap(long)]
    wait: bool,
}
304
// Arguments for `eval`: all of `PredictArgs` plus reporting options.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // Flattened, so `eval` accepts every `predict` flag directly.
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
    /// Print all individual example lines (default: up to 20)
    #[clap(long)]
    verbose: bool,
}
316
/// Which upstream LLM backs the `teacher` family of prediction providers.
/// `Display`/`FromStr` use short lowercase names ("sonnet45", "gpt52");
/// `model_name` maps to the provider-side model identifier.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    Sonnet46,
    #[default]
    Sonnet45,
    Gpt52,
}
324
325impl std::fmt::Display for TeacherBackend {
326 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327 match self {
328 TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
329 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
330 TeacherBackend::Gpt52 => write!(f, "gpt52"),
331 }
332 }
333}
334
impl std::str::FromStr for TeacherBackend {
    type Err = anyhow::Error;

    // Case-insensitive; accepts short aliases ("sonnet", "claude", "gpt",
    // "openai") in addition to the canonical names emitted by `Display`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
            "sonnet46" => Ok(TeacherBackend::Sonnet46),
            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
            // NOTE(review): looks like a legacy identifier kept for backward
            // compatibility with previously-serialized data — confirm before
            // removing. Intentionally omitted from the error message below.
            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
            _ => anyhow::bail!(
                "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
            ),
        }
    }
}
350
351impl TeacherBackend {
352 pub fn model_name(&self) -> &'static str {
353 match self {
354 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
355 TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
356 TeacherBackend::Gpt52 => "gpt-5.2",
357 }
358 }
359}
360
/// Which backend produces edit predictions. Variants carrying a `ZetaFormat`
/// also select the prompt/output format; teacher variants additionally pick
/// the backing LLM. String forms are defined by `Display`/`FromStr` below
/// and are what gets serialized into example JSONL.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Mercury,
    Zeta1,
    Zeta2(ZetaFormat),
    Baseten(ZetaFormat),
    Teacher(TeacherBackend, ZetaFormat),
    TeacherMultiRegion(TeacherBackend),
    TeacherNonBatching(TeacherBackend, ZetaFormat),
    TeacherMultiRegionNonBatching(TeacherBackend),
    Repair,
}
373
374impl Default for PredictionProvider {
375 fn default() -> Self {
376 PredictionProvider::Zeta2(ZetaFormat::default())
377 }
378}
379
impl std::fmt::Display for PredictionProvider {
    // Canonical string form; `FromStr`/`Deserialize` parse these back, and
    // `Serialize` emits exactly this.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PredictionProvider::Mercury => write!(f, "mercury"),
            PredictionProvider::Zeta1 => write!(f, "zeta1"),
            PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
            PredictionProvider::Baseten(format) => write!(f, "baseten:{format}"),
            // NOTE(review): the teacher variants below print the format with
            // `{:?}` (Debug) while Zeta2/Baseten use Display. If the Debug
            // spelling is not accepted by `ZetaFormat::parse`, a Serialize →
            // Deserialize round-trip of teacher providers would fail — confirm
            // and unify if so.
            PredictionProvider::Teacher(backend, format) => {
                write!(f, "teacher:{backend}:{format:?}")
            }
            PredictionProvider::TeacherMultiRegion(backend) => {
                write!(f, "teacher-multi-region:{backend}")
            }
            PredictionProvider::TeacherNonBatching(backend, format) => {
                write!(f, "teacher-non-batching:{backend}:{format:?}")
            }
            PredictionProvider::TeacherMultiRegionNonBatching(backend) => {
                write!(f, "teacher-multi-region-non-batching:{backend}")
            }
            PredictionProvider::Repair => write!(f, "repair"),
        }
    }
}
403
impl std::str::FromStr for PredictionProvider {
    type Err = anyhow::Error;

    // Accepts `provider` or `provider:arg` where `arg` is a format and/or
    // backend depending on the provider. Provider names are matched
    // case-insensitively and hyphen/underscore spellings are both accepted;
    // `Display` always emits the hyphenated form (see tests below).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Split only on the first ':'; teacher args may contain further ':'
        // separators handled by `parse_teacher_args`.
        let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));

        let provider_lower = provider.to_lowercase();
        match provider_lower.as_str() {
            "mercury" => Ok(PredictionProvider::Mercury),
            "zeta1" => Ok(PredictionProvider::Zeta1),
            "zeta2" => {
                let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
                Ok(PredictionProvider::Zeta2(format))
            }
            "teacher" => parse_teacher_args(arg),
            "teacher-non-batching" | "teacher_non_batching" => {
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::default());
                // Note: unlike plain `teacher`, no format argument is parsed
                // here — the default ZetaFormat is always used.
                Ok(PredictionProvider::TeacherNonBatching(
                    backend,
                    ZetaFormat::default(),
                ))
            }
            "teacher-multi-region" | "teacher_multi_region" => {
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::default());
                Ok(PredictionProvider::TeacherMultiRegion(backend))
            }
            "teacher-multi-region-non-batching" | "teacher_multi_region_non_batching" => {
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::default());
                Ok(PredictionProvider::TeacherMultiRegionNonBatching(backend))
            }
            "repair" => Ok(PredictionProvider::Repair),
            "baseten" => {
                let format = arg
                    .map(ZetaFormat::parse)
                    .transpose()?
                    .unwrap_or(ZetaFormat::default());
                Ok(PredictionProvider::Baseten(format))
            }
            _ => {
                anyhow::bail!(
                    "unknown provider `{provider}`. Valid options: mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-multi-region, teacher-multi-region:<backend>, teacher-non-batching, teacher-multi-region-non-batching, repair\n\
                    For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
                    For teacher providers, you can specify a backend like `teacher:sonnet46`, `teacher-multi-region:sonnet46`, `teacher-multi-region-non-batching:sonnet46`, or `teacher:gpt52`.\n\
                    Available zeta versions:\n{}",
                    ZetaFormat::options_as_string()
                )
            }
        }
    }
}
463
464fn parse_teacher_args(arg: Option<&str>) -> Result<PredictionProvider, anyhow::Error> {
465 let mut backend = TeacherBackend::default();
466 let mut format = ZetaFormat::default();
467
468 for arg in arg.unwrap_or_default().split(':') {
469 if arg.is_empty() {
470 continue;
471 }
472
473 if let Ok(parsed_backend) = TeacherBackend::from_str(arg) {
474 backend = parsed_backend;
475 } else if let Ok(parsed_format) = ZetaFormat::parse(arg) {
476 format = parsed_format;
477 } else {
478 anyhow::bail!("unknown teacher backend or zeta format `{arg}`");
479 }
480 }
481
482 Ok(PredictionProvider::Teacher(backend, format))
483}
484
485impl Serialize for PredictionProvider {
486 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
487 where
488 S: Serializer,
489 {
490 serializer.serialize_str(&self.to_string())
491 }
492}
493
494impl<'de> Deserialize<'de> for PredictionProvider {
495 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
496 where
497 D: Deserializer<'de>,
498 {
499 let s = String::deserialize(deserializer)?;
500 s.parse().map_err(serde::de::Error::custom)
501 }
502}
503
// Arguments for `synthesize`; translated into a `SynthesizeConfig` in `main`.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
522
// Arguments for `import-batch`; handled entirely inside `main` without the
// usual example-loading pipeline.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
532
// Which LLM vendor's batch API `import-batch` talks to; selects between
// `anthropic_client` and `openai_client` in `main`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
538
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trip tests: the canonical hyphenated spelling and the underscore
    // alias must both parse, and `Display` must always emit the canonical
    // hyphenated spelling regardless of which alias was parsed.
    #[test]
    fn prediction_provider_multi_region_non_batched_round_trips_to_primary_spelling() {
        let provider: PredictionProvider = "teacher-multi-region-non-batching:sonnet46"
            .parse()
            .unwrap();
        assert_eq!(
            provider,
            PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Sonnet46)
        );
        assert_eq!(
            provider.to_string(),
            "teacher-multi-region-non-batching:sonnet46"
        );
    }

    #[test]
    fn prediction_provider_multi_region_non_batched_alias_round_trips_to_primary_spelling() {
        let provider: PredictionProvider =
            "teacher_multi_region_non_batching:gpt52".parse().unwrap();
        assert_eq!(
            provider,
            PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Gpt52)
        );
        assert_eq!(
            provider.to_string(),
            "teacher-multi-region-non-batching:gpt52"
        );
    }
}
572
573impl EpArgs {
574 fn output_path(&self) -> Option<PathBuf> {
575 if self.in_place {
576 if self.inputs.len() == 1 {
577 self.inputs.first().cloned()
578 } else {
579 panic!("--in-place requires exactly one input file")
580 }
581 } else {
582 self.output.clone()
583 }
584 }
585}
586
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
// Passed to every `pull_examples::fetch_*` call so older-schema rows are
// filtered out server-side. Bump when the capture schema changes incompatibly.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224,
    patch: 1,
};
594
/// Deduplicates `examples` in place, keeping at most `max_per_cluster`
/// examples per near-duplicate cluster.
///
/// Two passes:
/// 1. Exact: drop examples whose `cursor_position` is byte-identical to an
///    earlier one.
/// 2. Near: MinHash-LSH over token n-grams of `cursor_position`, clustered
///    with union-find; within each cluster, a diverse subset is kept via
///    `greedy_max_min_diverse`.
///
/// Relative order of the surviving examples is preserved.
fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
    let total_before_exact = examples.len();
    let mut seen_positions = HashSet::default();
    // `insert` returns false for an already-seen position, dropping the dup.
    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
    log::info!(
        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
        examples.len(),
        total_before_exact - examples.len(),
    );

    // LSH tuning: pairs with estimated Jaccard similarity above the threshold
    // become clustering candidates.
    const JACCARD_THRESHOLD: f64 = 0.5;
    const NUM_HASHES: usize = 128;
    const TOKEN_NGRAM_SIZE: usize = 5;

    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
    // bands × width may differ slightly from NUM_HASHES; use the exact product.
    let num_hashes = num_bands * band_width;
    let minhasher = MinHasher32::new(num_hashes);
    let mut index: MinHashIndex<u32, usize> =
        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);

    // One MinHash signature per example, over token n-gram shingles.
    let signatures: Vec<Vec<u32>> = examples
        .iter()
        .map(|example| {
            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
            minhasher.create_signature(shingles.iter())
        })
        .collect();

    for (id, signature) in signatures.iter().enumerate() {
        index.insert(id, signature.clone());
    }

    // Build clusters via union-find on LSH candidate pairs.
    let mut parent: Vec<usize> = (0..examples.len()).collect();

    // Find with path halving (no union-by-rank; fine at this scale).
    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
        while parent[x] != x {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        x
    }

    for (id, signature) in signatures.iter().enumerate() {
        for candidate in index.query_owned(signature) {
            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
            if a != b {
                parent[a] = b;
            }
        }
    }

    // Group example indices by cluster root.
    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
    for id in 0..examples.len() {
        clusters.entry(find(&mut parent, id)).or_default().push(id);
    }

    let mut keep: HashSet<usize> = HashSet::default();
    for members in clusters.values() {
        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
        keep.extend(selected);
    }

    let total = examples.len();
    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
    kept_indices.sort();

    // Remove from the back with swap_remove (O(1) each); since indices are
    // processed in descending order, earlier kept indices stay valid. The
    // final reverse restores the original relative order.
    let mut retained = Vec::with_capacity(kept_indices.len());
    for index in kept_indices.into_iter().rev() {
        retained.push(examples.swap_remove(index));
    }
    retained.reverse();

    *examples = retained;
    log::info!(
        "near-duplicate filter: {total} examples → {} examples ({} removed)",
        examples.len(),
        total - examples.len(),
    );
}
675
/// Greedy max-min (farthest-point) selection of up to `k` members of a
/// near-duplicate cluster, maximizing diversity under MinHash distance
/// (1 − estimated Jaccard similarity).
///
/// Returns all members when the cluster has `k` or fewer. Otherwise seeds
/// with the first member and repeatedly adds the member whose minimum
/// distance to the already-selected set is largest.
fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
    if members.len() <= k {
        return members.to_vec();
    }

    let mut selected = vec![members[0]];
    // min_dist[m] = distance from m to its nearest selected member.
    let mut min_dist: HashMap<usize, f64> = HashMap::default();
    for &member in &members[1..] {
        let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
        min_dist.insert(member, dist);
    }

    while selected.len() < k {
        // Pick the candidate farthest from the selected set. Ties are broken
        // by map iteration order; NaN (shouldn't occur) compares as Equal.
        let &best = min_dist
            .iter()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(id, _)| id)
            .expect("min_dist should not be empty when selected.len() < k");
        selected.push(best);
        min_dist.remove(&best);

        // Newly selected member may now be the nearest one for the remainder.
        let best_sig = &signatures[best];
        for (member, current_min) in min_dist.iter_mut() {
            let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
            if dist < *current_min {
                *current_min = dist;
            }
        }
    }

    selected
}
708
709fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
710 let tokens: Vec<&str> = word_diff::tokenize(code)
711 .into_iter()
712 .filter(|t| !t.trim().is_empty())
713 .collect();
714
715 if tokens.len() < ngram_size {
716 return vec![tokens.join("\0")];
717 }
718
719 tokens
720 .windows(ngram_size)
721 .map(|window| window.join("\0"))
722 .collect()
723}
724
/// Collects the run's input examples from local files and Snowflake
/// specifiers, then applies offset, limit, name/repo filters, optional
/// deduplication, and resume-from-output logic.
///
/// Inputs in `args.inputs` that match a `*-after:{timestamp}` specifier are
/// fetched from Snowflake; everything else is treated as a file path. See
/// `INPUTS_HELP` for the accepted specifiers.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Bucket each input by which specifier (if any) it matches.
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut settled_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) = parse_settled_after_input(input_string.as_ref()) {
            settled_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            // Anything that isn't a recognized specifier is a file path.
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            examples.clear();
            offset - file_example_count
        } else {
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    Progress::global().set_total_examples(examples.len());

    // Only fetch from Snowflake up to whatever the --limit still allows.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        let max_rows_per_timestamp = remaining_limit_for_snowflake;

        // Each specifier kind is fetched separately; timestamps are sorted so
        // results come back in a stable order.
        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !settled_after_timestamps.is_empty() {
            settled_after_timestamps.sort();

            let mut settled_examples = fetch_settled_examples_after(
                http_client.clone(),
                &settled_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut settled_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            // Last fetch: `http_client` and `background_executor` are moved.
            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    // Group examples so per-repo work (worktrees, checkouts) is batched.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(max_duplicates) = args.max_duplicates {
        deduplicate_examples(&mut examples, max_duplicates);
    }

    // --limit caps the final set, after filters/dedup/resume.
    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
899
/// Stable fingerprint of an example spec, used to match input examples
/// against previously-written output lines when resuming a run.
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
905
906fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
907 let file = match File::open(path) {
908 Ok(f) => f,
909 Err(_) => return,
910 };
911
912 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
913
914 let reader = BufReader::new(file);
915 let mut kept_lines = Vec::new();
916 let mut kept_hashes = HashSet::default();
917
918 for line in reader.lines() {
919 let line = match line {
920 Ok(l) => l,
921 Err(_) => continue,
922 };
923
924 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
925 let hash = spec_hash(&output_example.spec);
926 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
927 let is_complete = match command {
928 Command::Qa(_) => output_example
929 .qa
930 .first()
931 .and_then(|q| q.as_ref())
932 .and_then(|q| q.confidence)
933 .is_some(),
934 Command::Repair(_) => output_example.predictions.iter().any(|p| {
935 p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
936 }),
937 _ => true,
938 };
939 if is_complete {
940 kept_hashes.insert(hash);
941 kept_lines.push(line);
942 }
943 }
944 }
945 }
946
947 let total = examples.len();
948 let already_processed = kept_hashes.len();
949
950 eprintln!(
951 "Resuming: {}/{} examples already processed",
952 already_processed, total
953 );
954
955 let file = OpenOptions::new()
956 .write(true)
957 .truncate(true)
958 .open(path)
959 .expect("Failed to open output file for rewriting");
960 let mut writer = BufWriter::new(file);
961 for line in &kept_lines {
962 writeln!(writer, "{}", line).expect("Failed to write to output file");
963 }
964 writer.flush().expect("Failed to flush output file");
965
966 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
967}
968
/// CLI entry point.
///
/// Dispatches in two stages:
/// 1. Self-contained subcommands (batch import, clean, synthesize, dataset
///    splitting, ...) run directly below and return early.
/// 2. All remaining commands boot a headless GPUI app, load the requested
///    examples, and process them on `args.max_parallelism` concurrent
///    workers, with examples grouped by repository.
fn main() {
    let args = EpArgs::parse();

    // `--printenv` just dumps the shell environment and exits.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    // Markdown mode writes one file per example, so an output directory is mandatory.
    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // No subcommand given: print help instead of erroring.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Stage 1: subcommands that don't run the example-processing pipeline.
    // Every arm returns early; anything that falls through is handled by the
    // GPUI-based pipeline further down.
    match &command {
        Command::ImportBatch(import_args) => {
            // Async provider calls run on a local smol executor; no GPUI app needed.
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Wipes the whole local data directory.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::PrintZetaFormats => {
            // Enumerate all ZetaFormat variants in lowercase, one per line.
            use strum::IntoEnumIterator as _;
            for format in ZetaFormat::iter() {
                println!("{}", format.to_string().to_lowercase());
            }
            return;
        }

        Command::Synthesize(synth_args) => {
            // Fall back to a default output dir inside the crate when -o is
            // omitted, but only if its parent already exists.
            let output_dir = if let Some(output_dir) = args.output {
                output_dir
            } else {
                let default_output_dir = env::current_dir()
                    .unwrap()
                    .join("crates/edit_prediction_cli/evals-generated");
                if default_output_dir.parent().unwrap().exists() {
                    std::fs::create_dir(&default_output_dir).ok();
                    default_output_dir
                } else {
                    panic!("output dir is required");
                }
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::TruncatePatch(truncate_args) => {
            if let Err(error) =
                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }

        _ => {}
    }

    // Stage 2: boot a headless GPUI app for commands that process examples.
    let http_client = Arc::new(ReqwestClient::new());
    let app = gpui_platform::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any pending provider batch state before processing starts.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // A single-example run fails fast so the error surfaces immediately.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode).
                // With --in-place we write to a `.jsonl.tmp` sibling that is
                // atomically renamed over the original at the very end.
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    // In-place runs truncate the temp file; otherwise append to the target.
                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // A single background task owns the writer; workers send
                    // serialized lines over an unbounded channel. Flushing per
                    // line keeps output durable if the process dies mid-run.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Workers pull one repository's worth of examples at a time
                // from this shared queue, so each repo is handled by one worker.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the per-example work for the active command.
                                let result = async {
                                    match &command {
                                        Command::Read(_) => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        // Handled (and early-returned) in stage 1 above;
                                        // these never reach the example loop.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::TruncatePatch(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::PrintZetaFormats => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are dropped from the output
                                // unless `--failed keep` was requested.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No sink configured: emit JSONL on stdout
                                        // (except for Eval, which prints a report instead).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Tear down the repo's loaded project (if any): shut
                            // down its language servers and detach it from the
                            // EditPredictionStore before taking the next repo.
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            // Release per-example state before collecting results.
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // With --wait, block until provider batches complete, reprocess
                // the finished examples, and rewrite the output with the final
                // results (the earlier per-example writes are superseded).
                let is_markdown = args.markdown;
                let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                        if args.wait {
                            predict::wait_for_batches(args.provider.as_ref()).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            predict::reprocess_after_batch_wait(&mut examples, args).await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                        if args.predict.wait {
                            predict::wait_for_batches(args.predict.provider.as_ref()).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            predict::reprocess_after_batch_wait(&mut examples, &args.predict)
                                .await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                        if args.wait {
                            repair::wait_for_batches(args).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            repair::reprocess_after_batch_wait(&mut examples, args).await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    _ => (),
                }

                // Final stderr/stdout reports for the commands that have one.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples, args.verbose);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    Command::Repair(args) => {
                        let examples = finished_examples.lock().unwrap();
                        repair::print_report(&examples, args.confidence_threshold);
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            // Any error that escaped the pipeline is fatal.
            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app so `app.run` returns.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1443
1444fn rewrite_output(
1445 examples: &[Example],
1446 output_path: Option<&PathBuf>,
1447 markdown: bool,
1448) -> anyhow::Result<()> {
1449 if markdown {
1450 let dir = output_path.context("--markdown requires -o")?;
1451 for example in examples {
1452 let filename = format!("{}.md", example.spec.filename());
1453 let path = dir.join(&filename);
1454 let markdown = example.spec.to_markdown();
1455 std::fs::write(&path, &markdown).context("Failed to write markdown file")?;
1456 }
1457 } else if let Some(path) = output_path {
1458 let file = OpenOptions::new()
1459 .create(true)
1460 .write(true)
1461 .truncate(true)
1462 .open(path)
1463 .context("Failed to open output file for rewriting")?;
1464 let mut writer = BufWriter::new(file);
1465 for example in examples {
1466 let line = serde_json::to_string(example)?;
1467 writeln!(writer, "{}", line)?;
1468 }
1469 writer.flush()?;
1470 } else {
1471 for example in examples {
1472 let line = serde_json::to_string(example)?;
1473 println!("{}", line);
1474 }
1475 }
1476 Ok(())
1477}
1478
1479async fn handle_error(
1480 error: anyhow::Error,
1481 args: &EpArgs,
1482 command: &Command,
1483 app_state: &Arc<headless::EpAppState>,
1484 failfast_on_single_example: bool,
1485 example: &Example,
1486) {
1487 Progress::global().increment_failed();
1488
1489 let msg;
1490 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1491 let example_name = example.spec.filename();
1492
1493 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1494 app_state
1495 .fs
1496 .write(
1497 &failed_example_path,
1498 &serde_json::to_vec_pretty(&example).unwrap(),
1499 )
1500 .await
1501 .unwrap();
1502 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1503 app_state
1504 .fs
1505 .write(&err_path, format!("{error:?}").as_bytes())
1506 .await
1507 .unwrap();
1508
1509 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1510 let mut file = OpenOptions::new()
1511 .create(true)
1512 .append(true)
1513 .open(&failed_jsonl_path)
1514 .expect("Failed to open failed.jsonl");
1515 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1516 .expect("Failed to write to failed.jsonl");
1517
1518 let cursor_path = match example.repo_name() {
1519 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1520 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1521 };
1522 msg = format!(
1523 indoc::indoc! {"
1524 While processing \"{}\":
1525
1526 \x1b[31m{:?}\x1b[0m
1527
1528 Example: \x1b[36m{}\x1b[0m
1529 Error file: \x1b[36m{}\x1b[0m
1530 Cursor file: \x1b[36m{}\x1b[0m
1531 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1532 "},
1533 example.spec.name,
1534 error,
1535 failed_example_path.display(),
1536 err_path.display(),
1537 cursor_path.display(),
1538 command,
1539 failed_example_path.display(),
1540 );
1541 } else {
1542 msg = format!(
1543 indoc::indoc! {"
1544 While processing \"{}\":
1545
1546 \x1b[31m{:?}\x1b[0m
1547 "},
1548 example.spec.name, error
1549 );
1550 }
1551
1552 if args.failfast || failfast_on_single_example {
1553 Progress::global().finalize();
1554 panic!("{}", msg);
1555 } else {
1556 log::error!("{}", msg);
1557 }
1558}