1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25mod synthesize;
26mod truncate_expected_patch;
27mod word_diff;
28use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
29use collections::HashSet;
30use edit_prediction::EditPredictionStore;
31use futures::channel::mpsc;
32use futures::{SinkExt as _, StreamExt as _};
33use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
34use zeta_prompt::ZetaVersion;
35
36use reqwest_client::ReqwestClient;
37use serde::{Deserialize, Deserializer, Serialize, Serializer};
38use std::fmt::Display;
39use std::fs::{File, OpenOptions};
40use std::hash::{Hash, Hasher};
41use std::io::{BufRead, BufReader, BufWriter, Write};
42use std::sync::Mutex;
43use std::{path::PathBuf, sync::Arc};
44
45use crate::distill::run_distill;
46use crate::example::{Example, group_examples_by_repo, read_example_files};
47use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
48use crate::format_prompt::run_format_prompt;
49use crate::load_project::run_load_project;
50use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
51use crate::predict::run_prediction;
52use crate::progress::Progress;
53use crate::retrieve_context::run_context_retrieval;
54use crate::score::run_scoring;
55use crate::split_commit::SplitCommitArgs;
56use crate::split_dataset::SplitArgs;
57use crate::synthesize::{SynthesizeConfig, run_synthesize};
58use crate::truncate_expected_patch::TruncatePatchArgs;
59
// Top-level CLI arguments for the `ep` tool. Most flags are `global`, so they
// may appear before or after the subcommand. Notes added below use plain `//`
// comments (not `///`) so the clap-derived `--help` output is unchanged.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit immediately (handled at the top of `main`).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of concurrent worker tasks processing per-repository example groups.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Skip this many examples from the front of the sorted/filtered list
    // (applied in `load_examples` before `--limit`).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or Snowflake specifiers (see INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file path — or output directory when `--markdown` is set.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (see `EpArgs::output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort the whole run as soon as one example fails.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
98
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // Unlike `Skip`, this also suppresses the per-example failure artifacts
    // (example JSON, error text, failed.jsonl line) written by `handle_error`.
    SkipNoFiles,
}
111
// User-visible long help text attached to the positional `inputs` argument of
// `EpArgs`. The string content is rendered verbatim by clap; do not reflow it.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.
      These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
      Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
      Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that users explicitly rated as positive or negative via the
      rate completions modal. Only zeta2 predictions are included.
      - Positive ratings: output becomes expected_patches
      - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
      Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
      Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
      EP_SNOWFLAKE_API_KEY
      EP_SNOWFLAKE_BASE_URL

  Optional:
      EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
169
// All `ep` subcommands. Variant `///` doc comments double as clap's per-command
// help text, so they are left as-is; the `Display` impl below renders each
// variant back into its CLI invocation form for re-run hints.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
}
211
212impl Display for Command {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 match self {
215 Command::Read => write!(f, "read"),
216 Command::LoadProject => write!(f, "load-project"),
217 Command::Context => write!(f, "context"),
218 Command::FormatPrompt(args) => {
219 write!(f, "format-prompt --provider={}", args.provider)
220 }
221 Command::Predict(args) => match &args.provider {
222 Some(provider) => write!(f, "predict --provider={}", provider),
223 None => write!(f, "predict"),
224 },
225 Command::ParseOutput => write!(f, "parse-output"),
226 Command::Score(args) => match &args.provider {
227 Some(provider) => write!(f, "score --provider={}", provider),
228 None => write!(f, "score"),
229 },
230 Command::Distill => write!(f, "distill"),
231 Command::Eval(args) => match &args.predict.provider {
232 Some(provider) => write!(f, "eval --provider={}", provider),
233 None => write!(f, "eval"),
234 },
235 Command::Synthesize(args) => {
236 write!(f, "synthesize --repos {}", args.repos.join(" "))
237 }
238 Command::Clean => write!(f, "clean"),
239 Command::SplitCommit(_) => write!(f, "split-commit"),
240 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
241 Command::Split(_) => write!(f, "split"),
242 Command::FilterLanguages(_) => write!(f, "filter-languages"),
243 Command::ImportBatch(args) => {
244 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
245 }
246 Command::Qa(_) => {
247 write!(f, "qa")
248 }
249 Command::Repair(_) => {
250 write!(f, "repair")
251 }
252 }
253 }
254}
255
// Arguments for `ep format-prompt`. (Plain `//` comments so clap help is unchanged.)
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to generate; defaults to zeta2 at its default version.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
261
// Arguments shared by `ep predict` and `ep score`, and embedded in `ep eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider; when omitted, the behavior is decided by the predict
    // implementation (not visible in this file).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Repetition count per example — consumed by the predict/score code; TODO
    // confirm exact semantics in predict.rs.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
272
// Arguments for `ep eval`: a full predict/score configuration plus an optional
// JSON summary destination (written by `score::write_summary_json`).
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
281
/// Backend model used by the `teacher` / `teacher-non-batching` prediction
/// providers. See [`TeacherBackend::model_name`] for the concrete API model ids.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Anthropic Claude (`claude-sonnet-4-5`).
    Sonnet45,
    /// OpenAI GPT (`gpt-5.2`).
    Gpt52,
}
287
288impl std::fmt::Display for TeacherBackend {
289 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
290 match self {
291 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
292 TeacherBackend::Gpt52 => write!(f, "gpt52"),
293 }
294 }
295}
296
297impl std::str::FromStr for TeacherBackend {
298 type Err = anyhow::Error;
299
300 fn from_str(s: &str) -> Result<Self, Self::Err> {
301 match s.to_lowercase().as_str() {
302 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
303 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
304 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
305 _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
306 }
307 }
308}
309
310impl TeacherBackend {
311 pub fn model_name(&self) -> &'static str {
312 match self {
313 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
314 TeacherBackend::Gpt52 => "gpt-5.2",
315 }
316 }
317}
318
/// Prediction providers selectable via `--provider`. Parsed from strings of
/// the form `name[:arg]` (see the `FromStr` impl below) and displayed back in
/// the same syntax.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    /// Zeta2 with an explicit prompt-version selection (`zeta2:<version>`).
    Zeta2(ZetaVersion),
    /// Teacher-model provider (see [`TeacherBackend`]).
    Teacher(TeacherBackend),
    /// Teacher-model provider that, per its name, avoids request batching;
    /// the actual request handling lives outside this file.
    TeacherNonBatching(TeacherBackend),
    Repair,
}
329
330impl Default for PredictionProvider {
331 fn default() -> Self {
332 PredictionProvider::Zeta2(ZetaVersion::default())
333 }
334}
335
336impl std::fmt::Display for PredictionProvider {
337 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
338 match self {
339 PredictionProvider::Sweep => write!(f, "sweep"),
340 PredictionProvider::Mercury => write!(f, "mercury"),
341 PredictionProvider::Zeta1 => write!(f, "zeta1"),
342 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
343 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
344 PredictionProvider::TeacherNonBatching(backend) => {
345 write!(f, "teacher-non-batching:{backend}")
346 }
347 PredictionProvider::Repair => write!(f, "repair"),
348 }
349 }
350}
351
352impl std::str::FromStr for PredictionProvider {
353 type Err = anyhow::Error;
354
355 fn from_str(s: &str) -> Result<Self, Self::Err> {
356 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
357
358 let provider_lower = provider.to_lowercase();
359 match provider_lower.as_str() {
360 "sweep" => Ok(PredictionProvider::Sweep),
361 "mercury" => Ok(PredictionProvider::Mercury),
362 "zeta1" => Ok(PredictionProvider::Zeta1),
363 "zeta2" => {
364 let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
365 Ok(PredictionProvider::Zeta2(version))
366 }
367 "teacher" => {
368 let backend = arg
369 .map(|a| a.parse())
370 .transpose()?
371 .unwrap_or(TeacherBackend::Sonnet45);
372 Ok(PredictionProvider::Teacher(backend))
373 }
374 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
375 let backend = arg
376 .map(|a| a.parse())
377 .transpose()?
378 .unwrap_or(TeacherBackend::Sonnet45);
379 Ok(PredictionProvider::TeacherNonBatching(backend))
380 }
381 "repair" => Ok(PredictionProvider::Repair),
382 _ => {
383 anyhow::bail!(
384 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
385 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
386 For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
387 Available zeta versions:\n{}",
388 ZetaVersion::options_as_string()
389 )
390 }
391 }
392 }
393}
394
395impl Serialize for PredictionProvider {
396 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
397 where
398 S: Serializer,
399 {
400 serializer.serialize_str(&self.to_string())
401 }
402}
403
404impl<'de> Deserialize<'de> for PredictionProvider {
405 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
406 where
407 D: Deserializer<'de>,
408 {
409 let s = String::deserialize(deserializer)?;
410 s.parse().map_err(serde::de::Error::custom)
411 }
412}
413
// Arguments for `ep synthesize` (mirrors the fields of `SynthesizeConfig`,
// which `main` builds from these values plus the `-o` output directory).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
432
// Arguments for `ep import-batch`, which re-imports already-submitted LLM
// batch results into the local cache database.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
442
// Which LLM batch API `ep import-batch` talks to; selects between the
// Anthropic and OpenAI client implementations in `main`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
448
449impl EpArgs {
450 fn output_path(&self) -> Option<PathBuf> {
451 if self.in_place {
452 if self.inputs.len() == 1 {
453 self.inputs.first().cloned()
454 } else {
455 panic!("--in-place requires exactly one input file")
456 }
457 } else {
458 self.output.clone()
459 }
460 }
461}
462
/// Builds the example list for a run.
///
/// Inputs are classified into local files and Snowflake specifiers
/// (`captured-after:`, `rejected-after:`, `requested-after:`, `rated-after:`
/// variants — see `INPUTS_HELP`). Files are read first; Snowflake sources are
/// then fetched up to the remaining `--limit` budget. The merged set is
/// sorted, filtered by `--name`/`--repo`, trimmed by resume state and
/// `--offset`/`--limit`, and registered with the global progress reporter.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    // Classify each input: Snowflake specifiers are matched first; anything
    // that doesn't parse as one is treated as a local example file path.
    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    Progress::global().set_total_examples(examples.len());

    // Budget left for Snowflake rows after counting locally-read examples.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // 5000 matches the documented default when pulling from Snowflake
        // without an explicit --limit (see the `EpArgs::limit` doc comment).
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                background_executor.clone(),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                background_executor.clone(),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                background_executor.clone(),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                background_executor,
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    // Sort so examples from the same repository/revision end up adjacent;
    // `main` later groups them per repo for processing.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    // NOTE(review): `splice(0..offset, [])` panics if offset > examples.len()
    // — TODO confirm --offset is expected to stay in range.
    if let Some(offset) = args.offset {
        examples.splice(0..offset, []);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Re-register totals now that filtering/offset/limit have been applied.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
597
/// Content hash of an example spec, used by `resume_from_output` to match
/// examples between this run's input and a previously written output file.
/// FxHasher is unseeded, so hashes are comparable across separate runs of the
/// same build.
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
603
/// Implements run resumption: scans the existing output file at `path`, keeps
/// the lines whose spec matches an example in this run's input AND that look
/// "complete" for the given command, rewrites the output file with only those
/// lines, and removes the corresponding examples from `examples` so they are
/// not reprocessed.
fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
    // A missing/unreadable output file just means there is nothing to resume.
    let file = match File::open(path) {
        Ok(f) => f,
        Err(_) => return,
    };

    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();

    let reader = BufReader::new(file);
    let mut kept_lines = Vec::new();
    let mut kept_hashes = HashSet::default();

    for line in reader.lines() {
        // Unreadable or non-JSON lines are silently dropped; the rewrite
        // below purges them from the file.
        let line = match line {
            Ok(l) => l,
            Err(_) => continue,
        };

        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
            let hash = spec_hash(&output_example.spec);
            // Keep only the first occurrence of each spec, and only specs
            // that are part of this run's input set.
            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
                // Command-specific notion of "already processed": the output
                // must actually contain this command's result, not just the
                // example itself.
                let is_complete = match command {
                    Command::Qa(_) => output_example
                        .qa
                        .first()
                        .and_then(|q| q.as_ref())
                        .and_then(|q| q.confidence)
                        .is_some(),
                    Command::Repair(_) => output_example.predictions.iter().any(|p| {
                        p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
                    }),
                    _ => true,
                };
                if is_complete {
                    kept_hashes.insert(hash);
                    kept_lines.push(line);
                }
            }
        }
    }

    let total = examples.len();
    let already_processed = kept_hashes.len();

    eprintln!(
        "Resuming: {}/{} examples already processed",
        already_processed, total
    );

    // Rewrite the output file to contain exactly the kept lines (this drops
    // duplicates, incomplete results, and lines not in this run's input).
    let file = OpenOptions::new()
        .write(true)
        .truncate(true)
        .open(path)
        .expect("Failed to open output file for rewriting");
    let mut writer = BufWriter::new(file);
    for line in &kept_lines {
        writeln!(writer, "{}", line).expect("Failed to write to output file");
    }
    writer.flush().expect("Failed to flush output file");

    // Finally, drop the already-processed examples from this run's workload.
    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
}
666
/// CLI entry point.
///
/// Parses `EpArgs`, handles the subcommands that need no example pipeline
/// synchronously (import-batch, clean, synthesize, split-commit,
/// truncate-patch, split, filter-languages), then runs the remaining commands
/// over the loaded examples inside a headless GPUI application — processing
/// examples per-repository with bounded parallelism and streaming results to
/// stdout, a JSONL file, or per-example markdown files.
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // No subcommand: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that don't operate on the example pipeline are handled
    // synchronously here and return early.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::TruncatePatch(truncate_args) => {
            if let Err(error) =
                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }

        _ => {}
    }

    // Everything remaining runs inside a headless GPUI application so that
    // project loading, context retrieval, and prediction can use app state.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any pending LLM batches for the command's provider
                // before processing begins.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // A single-example run always fails fast, even without --failfast.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // For --in-place, write to a temp file and rename at the end;
                    // otherwise append directly to the output path.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // A dedicated background task owns the writer; workers send
                    // finished lines over a channel and each line is flushed.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Work queue of per-repository example groups, shared by workers.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                // Worker pool: each task repeatedly claims one repo's group and
                // processes its examples sequentially.
                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the example to the selected command.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        // These commands were all dispatched
                                        // synchronously and returned early above.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::TruncatePatch(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Write the example out unless it failed and
                                // --failed was set to skip failures.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No output destination: emit JSONL to stdout
                                        // (except for eval, which prints a report below).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Repo group finished: shut down its language servers
                            // and drop cached project state before moving on.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Sync batches again so requests queued during processing are
                // submitted/collected.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // Eval prints an aggregate report (and optional JSON summary)
                // instead of per-example output.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1094
/// Records one failed example: bumps the failure counter, writes the failure
/// artifacts (example JSON, error text, a line in the run's failed.jsonl)
/// unless `--failed skip-no-files` was given, and then either panics (with
/// `--failfast`, or when the run had exactly one example) or logs the error
/// and lets the run continue.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();

    let msg;
    // SkipNoFiles suppresses all on-disk failure artifacts.
    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
        let example_name = example.spec.filename();

        // Pretty-printed copy of the failing example, for re-running it.
        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
        app_state
            .fs
            .write(
                &failed_example_path,
                &serde_json::to_vec_pretty(&example).unwrap(),
            )
            .await
            .unwrap();
        // Companion file with the full debug-formatted error.
        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
        app_state
            .fs
            .write(&err_path, format!("{error:?}").as_bytes())
            .await
            .unwrap();

        // Also append the example to this run's cumulative failed.jsonl.
        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
        let mut file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&failed_jsonl_path)
            .expect("Failed to open failed.jsonl");
        writeln!(file, "{}", serde_json::to_string(example).unwrap())
            .expect("Failed to write to failed.jsonl");

        // Best-effort path to the file under the cursor, for the hint below;
        // falls back to the raw spec path if the repo name can't be resolved.
        let cursor_path = match example.repo_name() {
            Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
            Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
        };
        // The \x1b[..m sequences are ANSI colors (red error, cyan paths).
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m

                Example: \x1b[36m{}\x1b[0m
                Error file: \x1b[36m{}\x1b[0m
                Cursor file: \x1b[36m{}\x1b[0m
                Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
            "},
            example.spec.name,
            error,
            failed_example_path.display(),
            err_path.display(),
            cursor_path.display(),
            command,
            failed_example_path.display(),
        );
    } else {
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m
            "},
            example.spec.name, error
        );
    }

    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}