mod anthropic_client;
mod distill;
mod example;
mod filter_languages;
mod format_prompt;
mod git;
mod headless;
mod load_project;
mod metrics;
mod openai_client;
mod parse_output;
mod paths;
mod predict;
mod progress;
mod pull_examples;
mod qa;
mod reorder_patch;
mod repair;
mod retrieve_context;
mod reversal_tracking;
mod score;
mod split_commit;
mod split_dataset;
mod synthesize;
mod word_diff;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
use zeta_prompt::ZetaVersion;

use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::sync::Mutex;
use std::{path::PathBuf, sync::Arc};

use crate::distill::run_distill;
use crate::example::{Example, group_examples_by_repo, read_example_files};
use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
use crate::split_dataset::SplitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};

#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    #[arg(long, default_value_t = false)]
    printenv: bool,
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    #[clap(long, global = true)]
    limit: Option<usize>,
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    #[arg(long, short, global = true)]
    in_place: bool,
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in the output: keep them, skip them, or
    /// skip writing failure files entirely.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
}

/// Controls whether failed examples are included in the main output.
/// Unless `SkipNoFiles` is selected, failed examples are also logged to the
/// run's failed/ directory.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Exclude failed examples from the main output and skip writing failure files
    SkipNoFiles,
}

const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to a file containing one or more examples (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

    Required environment variables to connect to Snowflake:
      EP_SNOWFLAKE_API_KEY
      EP_SNOWFLAKE_BASE_URL

    Optional:
      EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;

#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Run edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses the provider from the prompt.
    ParseOutput,
    /// Compute a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepare a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
}

impl Display for Command {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Command::Read => write!(f, "read"),
            Command::LoadProject => write!(f, "load-project"),
            Command::Context => write!(f, "context"),
            Command::FormatPrompt(args) => {
                write!(f, "format-prompt --provider={}", args.provider)
            }
            Command::Predict(args) => match &args.provider {
                Some(provider) => write!(f, "predict --provider={}", provider),
                None => write!(f, "predict"),
            },
            Command::ParseOutput => write!(f, "parse-output"),
            Command::Score(args) => match &args.provider {
                Some(provider) => write!(f, "score --provider={}", provider),
                None => write!(f, "score"),
            },
            Command::Distill => write!(f, "distill"),
            Command::Eval(args) => match &args.predict.provider {
                Some(provider) => write!(f, "eval --provider={}", provider),
                None => write!(f, "eval"),
            },
            Command::Synthesize(args) => {
                write!(f, "synthesize --repos {}", args.repos.join(" "))
            }
            Command::Clean => write!(f, "clean"),
            Command::SplitCommit(_) => write!(f, "split-commit"),
            Command::Split(_) => write!(f, "split"),
            Command::FilterLanguages(_) => write!(f, "filter-languages"),
            Command::ImportBatch(args) => {
                write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
            }
            Command::Qa(_) => write!(f, "qa"),
            Command::Repair(_) => write!(f, "repair"),
        }
    }
}

#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}

#[derive(Debug, Args, Clone)]
struct PredictArgs {
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}

#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    Sonnet45,
    Gpt52,
}

impl std::fmt::Display for TeacherBackend {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
            TeacherBackend::Gpt52 => write!(f, "gpt52"),
        }
    }
}

impl std::str::FromStr for TeacherBackend {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
            _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
        }
    }
}

impl TeacherBackend {
    pub fn model_name(&self) -> &'static str {
        match self {
            TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
            TeacherBackend::Gpt52 => "gpt-5.2",
        }
    }
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2(ZetaVersion),
    Teacher(TeacherBackend),
    TeacherNonBatching(TeacherBackend),
    Repair,
}

impl Default for PredictionProvider {
    fn default() -> Self {
        PredictionProvider::Zeta2(ZetaVersion::default())
    }
}

impl std::fmt::Display for PredictionProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PredictionProvider::Sweep => write!(f, "sweep"),
            PredictionProvider::Mercury => write!(f, "mercury"),
            PredictionProvider::Zeta1 => write!(f, "zeta1"),
            PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
            PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
            PredictionProvider::TeacherNonBatching(backend) => {
                write!(f, "teacher-non-batching:{backend}")
            }
            PredictionProvider::Repair => write!(f, "repair"),
        }
    }
}

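/// Parses a provider spec of the form `name` or `name:arg`, matched
/// case-insensitively, e.g. `zeta2:ordered` or `teacher:gpt52`.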
impl std::str::FromStr for PredictionProvider {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));

        let provider_lower = provider.to_lowercase();
        match provider_lower.as_str() {
            "sweep" => Ok(PredictionProvider::Sweep),
            "mercury" => Ok(PredictionProvider::Mercury),
            "zeta1" => Ok(PredictionProvider::Zeta1),
            "zeta2" => {
                let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
                Ok(PredictionProvider::Zeta2(version))
            }
            "teacher" => {
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::Sonnet45);
                Ok(PredictionProvider::Teacher(backend))
            }
            "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::Sonnet45);
                Ok(PredictionProvider::TeacherNonBatching(backend))
            }
            "repair" => Ok(PredictionProvider::Repair),
            _ => {
                anyhow::bail!(
                    "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
                     For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
                     For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
                     Available zeta versions:\n{}",
                    ZetaVersion::options_as_string()
                )
            }
        }
    }
}

impl Serialize for PredictionProvider {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(&self.to_string())
    }
}

impl<'de> Deserialize<'de> for PredictionProvider {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        s.parse().map_err(serde::de::Error::custom)
    }
}

#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}

#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}

impl EpArgs {
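    /// Resolves the effective output path: with `--in-place`, the single input
    /// path is reused (panicking unless there is exactly one input); otherwise
    /// the value of `--output`, if any.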
    fn output_path(&self) -> Option<PathBuf> {
        if self.in_place {
            if self.inputs.len() == 1 {
                self.inputs.first().cloned()
            } else {
                panic!("--in-place requires exactly one input file")
            }
        } else {
            self.output.clone()
        }
    }
}

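/// Gathers examples from file inputs and Snowflake specifiers, applies the
/// global name/repo/offset/limit filters, and (except with `--in-place`) drops
/// examples already present in the output file so interrupted runs can resume.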
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    Progress::global().set_total_examples(examples.len());

    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                background_executor.clone(),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client,
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                background_executor,
            )
            .await?;
            examples.append(&mut rejected_examples);
        }
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path {
            resume_from_output(path, &mut examples);
        }
    }

    if let Some(offset) = args.offset {
        examples.splice(0..offset, []);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}

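/// Hashes an example spec with `FxHasher` to get a cheap, non-cryptographic
/// identity key used for resume/dedup bookkeeping.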
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}

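/// Resumes from a partially written output file: keeps previously written
/// lines whose spec hash matches a current input, rewrites the file with just
/// those lines, and removes the corresponding examples from `examples` so only
/// unprocessed ones are run.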
fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
    let file = match File::open(path) {
        Ok(f) => f,
        Err(_) => return,
    };

    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();

    let reader = BufReader::new(file);
    let mut kept_lines = Vec::new();
    let mut kept_hashes = HashSet::default();

    for line in reader.lines() {
        let line = match line {
            Ok(l) => l,
            Err(_) => continue,
        };

        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
            let hash = spec_hash(&output_example.spec);
            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
                kept_hashes.insert(hash);
                kept_lines.push(line);
            }
        }
    }

    let total = examples.len();
    let already_processed = kept_hashes.len();

    eprintln!(
        "Resuming: {}/{} examples already processed",
        already_processed, total
    );

    let file = OpenOptions::new()
        .write(true)
        .truncate(true)
        .open(path)
        .expect("Failed to open output file for rewriting");
    let mut writer = BufWriter::new(file);
    for line in &kept_lines {
        writeln!(writer, "{}", line).expect("Failed to write to output file");
    }
    writer.flush().expect("Failed to flush output file");

    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
}

fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Qa(qa_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::Repair(repair_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) =
                    repair::run_repair(&mut examples, repair_args, output.as_ref()).await
                {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                let failfast_on_single_example = examples.len() == 1;

                // For --in-place, write to a temp file and rename at the end to avoid data loss on interruption
                let in_place_temp_path = if args.in_place {
                    output.as_ref().map(|path| {
                        let mut temp_path = path.clone();
                        temp_path.set_extension("jsonl.tmp");
                        temp_path
                    })
                } else {
                    None
                };

                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                        write_path.map(|path| {
                            let file = if args.in_place {
                                // For --in-place, write to the temp file (truncating if it exists)
                                OpenOptions::new()
                                    .create(true)
                                    .write(true)
                                    .truncate(true)
                                    .open(path)
                                    .expect("Failed to open temp output file")
                            } else {
                                // For regular output, append to support resuming
                                OpenOptions::new()
                                    .create(true)
                                    .append(true)
                                    .open(path)
                                    .expect("Failed to open output file")
                            };
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::Qa(_)
                                        | Command::Repair(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename the temp file over the original
                if let (Some(temp_path), Some(final_path)) = (&in_place_temp_path, &output) {
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}

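/// Records a failed example: unless `--failed skip-no-files`, writes the
/// example and its error under the run's failed directory and appends it to
/// `failed.jsonl`. Panics with `--failfast` (or when running a single
/// example); otherwise logs the error and continues.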
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();

    let msg;
    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
        let example_name = example.spec.filename();

        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
        app_state
            .fs
            .write(
                &failed_example_path,
                &serde_json::to_vec_pretty(&example).unwrap(),
            )
            .await
            .unwrap();
        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
        app_state
            .fs
            .write(&err_path, format!("{error:?}").as_bytes())
            .await
            .unwrap();

        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
        let mut file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&failed_jsonl_path)
            .expect("Failed to open failed.jsonl");
        writeln!(file, "{}", serde_json::to_string(example).unwrap())
            .expect("Failed to write to failed.jsonl");

        let cursor_path = example
            .repo_name()
            .unwrap()
            .worktree_path()
            .join(&example.spec.cursor_path);
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m

                Example: \x1b[36m{}\x1b[0m
                Error file: \x1b[36m{}\x1b[0m
                Cursor file: \x1b[36m{}\x1b[0m
                Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
            "},
            example.spec.name,
            error,
            failed_example_path.display(),
            err_path.display(),
            cursor_path.display(),
            command,
            failed_example_path.display(),
        );
    } else {
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m
            "},
            example.spec.name, error
        );
    }

    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}
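
// Minimal sanity-check sketch for the `FromStr`/`Display` round-trips defined
// above. These only exercise parse paths that don't depend on `ZetaVersion`,
// whose variants live outside this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn teacher_backend_aliases_parse() {
        assert_eq!(
            "claude".parse::<TeacherBackend>().unwrap(),
            TeacherBackend::Sonnet45
        );
        assert_eq!(
            "openai".parse::<TeacherBackend>().unwrap(),
            TeacherBackend::Gpt52
        );
        assert!("mystery".parse::<TeacherBackend>().is_err());
    }

    #[test]
    fn provider_display_round_trips() {
        for provider in [
            PredictionProvider::Sweep,
            PredictionProvider::Mercury,
            PredictionProvider::Zeta1,
            PredictionProvider::Teacher(TeacherBackend::Gpt52),
            PredictionProvider::TeacherNonBatching(TeacherBackend::Sonnet45),
            PredictionProvider::Repair,
        ] {
            let reparsed: PredictionProvider = provider.to_string().parse().unwrap();
            assert_eq!(reparsed, provider);
        }
    }
}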