mod anthropic_client;
mod distill;
mod example;
mod filter_languages;
mod format_prompt;
mod git;
mod headless;
mod load_project;
mod metrics;
mod openai_client;
mod parse_output;
mod paths;
mod predict;
mod progress;
mod pull_examples;
mod qa;
mod reorder_patch;
mod repair;
mod retrieve_context;
mod reversal_tracking;
mod score;
mod split_commit;
mod split_dataset;
mod synthesize;
mod word_diff;

use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
use zeta_prompt::ZetaVersion;

use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::sync::Mutex;
use std::{path::PathBuf, sync::Arc};

use crate::distill::run_distill;
use crate::example::{Example, group_examples_by_repo, read_example_files};
use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
use crate::split_dataset::SplitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};

#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    #[arg(long, default_value_t = false)]
    printenv: bool,
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    #[clap(long, global = true)]
    limit: Option<usize>,
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    #[arg(long, short, global = true)]
    in_place: bool,
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in the output: keep them or skip them.
    /// Unless `skip-no-files` is selected, failed examples are also logged to
    /// the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}

/// Controls whether failed examples are included in the main output.
/// Failed examples are also logged to the run's failed/ directory unless
/// `SkipNoFiles` is selected.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Exclude failed examples from the main output and skip writing
    /// per-example failure files
    SkipNoFiles,
}

const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to a file of examples (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  requested-after:{timestamp}
    Fetch requested edit predictions from Snowflake after the given RFC3339 timestamp.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
142
143#[derive(Subcommand, Debug, Clone)]
144enum Command {
145 /// Read examples from files or fetch from Snowflake, output as .jsonl
146 Read,
147 /// Create git worktrees for each example and load file contents
148 LoadProject,
149 /// Retrieve context for input examples.
150 Context,
151 /// Generate a prompt string for a specific model
152 FormatPrompt(FormatPromptArgs),
153 /// Runs edit prediction
154 Predict(PredictArgs),
155 /// Parse model outputs (actual_output) into unified diffs (actual_patch).
156 /// Requires format-prompt to have been run first. Uses provider from prompt.
157 ParseOutput,
158 /// Computes a score based on actual and expected patches
159 Score(PredictArgs),
160 /// Prepares a distillation dataset by copying expected outputs to
161 /// predicted outputs and removing actual outputs and prompts.
162 Distill,
163 /// Print aggregated scores
164 Eval(EvalArgs),
165 /// Generate eval examples by analyzing git commits from a repository
166 Synthesize(SynthesizeArgs),
167 /// Remove git repositories and worktrees
168 Clean,
169 /// Generate an evaluation example by splitting a chronologically-ordered commit
170 SplitCommit(SplitCommitArgs),
171 /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
172 Split(SplitArgs),
173 /// Filter a JSONL dataset by programming language (based on cursor_path extension)
174 FilterLanguages(FilterLanguagesArgs),
175 /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
176 ImportBatch(ImportBatchArgs),
177 /// Assess the quality of predictions using LLM-as-a-judge
178 Qa(qa::QaArgs),
179 /// Repair predictions that received poor QA scores by generating improved predictions
180 Repair(repair::RepairArgs),
181}

impl Display for Command {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Command::Read => write!(f, "read"),
            Command::LoadProject => write!(f, "load-project"),
            Command::Context => write!(f, "context"),
            Command::FormatPrompt(args) => {
                write!(f, "format-prompt --provider={}", args.provider)
            }
            Command::Predict(args) => match &args.provider {
                Some(provider) => write!(f, "predict --provider={}", provider),
                None => write!(f, "predict"),
            },
            Command::ParseOutput => write!(f, "parse-output"),
            Command::Score(args) => match &args.provider {
                Some(provider) => write!(f, "score --provider={}", provider),
                None => write!(f, "score"),
            },
            Command::Distill => write!(f, "distill"),
            Command::Eval(args) => match &args.predict.provider {
                Some(provider) => write!(f, "eval --provider={}", provider),
                None => write!(f, "eval"),
            },
            Command::Synthesize(args) => {
                write!(f, "synthesize --repos {}", args.repos.join(" "))
            }
            Command::Clean => write!(f, "clean"),
            Command::SplitCommit(_) => write!(f, "split-commit"),
            Command::Split(_) => write!(f, "split"),
            Command::FilterLanguages(_) => write!(f, "filter-languages"),
            Command::ImportBatch(args) => {
                write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
            }
            Command::Qa(_) => write!(f, "qa"),
            Command::Repair(_) => write!(f, "repair"),
        }
    }
}

#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}

#[derive(Debug, Args, Clone)]
struct PredictArgs {
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}

#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    Sonnet45,
    Gpt52,
}

impl std::fmt::Display for TeacherBackend {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
            TeacherBackend::Gpt52 => write!(f, "gpt52"),
        }
    }
}

impl std::str::FromStr for TeacherBackend {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
            // Legacy specifier, still accepted and mapped to the Sonnet45 backend.
            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
            _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
        }
    }
}

impl TeacherBackend {
    pub fn model_name(&self) -> &'static str {
        match self {
            TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
            TeacherBackend::Gpt52 => "gpt-5.2",
        }
    }
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2(ZetaVersion),
    Teacher(TeacherBackend),
    TeacherNonBatching(TeacherBackend),
    Repair,
}

impl Default for PredictionProvider {
    fn default() -> Self {
        PredictionProvider::Zeta2(ZetaVersion::default())
    }
}

impl std::fmt::Display for PredictionProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PredictionProvider::Sweep => write!(f, "sweep"),
            PredictionProvider::Mercury => write!(f, "mercury"),
            PredictionProvider::Zeta1 => write!(f, "zeta1"),
            PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
            PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
            PredictionProvider::TeacherNonBatching(backend) => {
                write!(f, "teacher-non-batching:{backend}")
            }
            PredictionProvider::Repair => write!(f, "repair"),
        }
    }
}

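// Provider specifiers take the form `name[:arg]`: the optional argument selects
// a zeta version (e.g. `zeta2:ordered`) or a teacher backend (e.g.
// `teacher:gpt52`). A bare `teacher` defaults to the Sonnet45 backend.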
impl std::str::FromStr for PredictionProvider {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));

        let provider_lower = provider.to_lowercase();
        match provider_lower.as_str() {
            "sweep" => Ok(PredictionProvider::Sweep),
            "mercury" => Ok(PredictionProvider::Mercury),
            "zeta1" => Ok(PredictionProvider::Zeta1),
            "zeta2" => {
                let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
                Ok(PredictionProvider::Zeta2(version))
            }
            "teacher" => {
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::Sonnet45);
                Ok(PredictionProvider::Teacher(backend))
            }
            "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::Sonnet45);
                Ok(PredictionProvider::TeacherNonBatching(backend))
            }
            "repair" => Ok(PredictionProvider::Repair),
            _ => {
                anyhow::bail!(
                    "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
                     For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
                     For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
                     Available zeta versions:\n{}",
                    ZetaVersion::options_as_string()
                )
            }
        }
    }
}
361
362impl Serialize for PredictionProvider {
363 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
364 where
365 S: Serializer,
366 {
367 serializer.serialize_str(&self.to_string())
368 }
369}
370
371impl<'de> Deserialize<'de> for PredictionProvider {
372 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
373 where
374 D: Deserializer<'de>,
375 {
376 let s = String::deserialize(deserializer)?;
377 s.parse().map_err(serde::de::Error::custom)
378 }
379}
380
381#[derive(Debug, Args, Clone)]
382struct SynthesizeArgs {
383 /// Repository URLs (git@github.com:owner/repo or https://...)
384 #[clap(long, required = true, num_args = 1..)]
385 repos: Vec<String>,
386
387 /// Number of examples to generate per repository
388 #[clap(long, default_value_t = 5)]
389 count: usize,
390
391 /// Maximum commits to scan per repository before giving up
392 #[clap(long, default_value_t = 100)]
393 max_commits: usize,
394
395 /// Ignore state file and reprocess all commits
396 #[clap(long)]
397 fresh: bool,
398}
399
400#[derive(Debug, Args, Clone)]
401struct ImportBatchArgs {
402 /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
403 #[clap(long, required = true, num_args = 1..)]
404 batch_ids: Vec<String>,
405 /// Which provider's batches to import (anthropic or openai)
406 #[clap(long, default_value = "anthropic")]
407 provider: BatchProvider,
408}
409
410#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
411enum BatchProvider {
412 Anthropic,
413 Openai,
414}
415
impl EpArgs {
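    /// Resolves the effective output path: with --in-place, the single input
    /// file doubles as the output; otherwise the -o flag (if any) is used.
    /// Panics when --in-place is combined with zero or multiple inputs.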
    fn output_path(&self) -> Option<PathBuf> {
        if self.in_place {
            if self.inputs.len() == 1 {
                self.inputs.first().cloned()
            } else {
                panic!("--in-place requires exactly one input file")
            }
        } else {
            self.output.clone()
        }
    }
}

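// Gathers examples from the mixed input list: each input is classified as a
// Snowflake specifier (captured-after:, rejected-after:, requested-after:) or a
// file path, fetched or read accordingly, then filtered by --name/--repo,
// reconciled against any existing output for resume, and windowed by
// --offset/--limit.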
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    Progress::global().set_total_examples(examples.len());

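    // File-based examples count against --limit first; only the remainder is
    // requested from Snowflake, defaulting to 5000 rows per timestamp when no
    // limit is given.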
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                background_executor.clone(),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                background_executor.clone(),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client,
                &requested_after_timestamps,
                max_rows_per_timestamp,
                background_executor,
            )
            .await?;
            examples.append(&mut requested_examples);
        }
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path {
            resume_from_output(path, &mut examples);
        }
    }

    if let Some(offset) = args.offset {
        // Clamp so an oversized --offset empties the list instead of panicking.
        examples.drain(0..offset.min(examples.len()));
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}

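/// Stable identity for an example, derived by hashing its spec; used to match
/// input examples against previously written output lines when resuming.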
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}

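/// Implements resume: rewrites the output file to keep only deduplicated lines
/// whose spec matches a current input example, then drops those
/// already-processed examples from the in-memory list. A missing or unreadable
/// output file is treated as "nothing processed yet".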
fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
    let file = match File::open(path) {
        Ok(f) => f,
        Err(_) => return,
    };

    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();

    let reader = BufReader::new(file);
    let mut kept_lines = Vec::new();
    let mut kept_hashes = HashSet::default();

    for line in reader.lines() {
        let line = match line {
            Ok(l) => l,
            Err(_) => continue,
        };

        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
            let hash = spec_hash(&output_example.spec);
            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
                kept_hashes.insert(hash);
                kept_lines.push(line);
            }
        }
    }

    let total = examples.len();
    let already_processed = kept_hashes.len();

    eprintln!(
        "Resuming: {}/{} examples already processed",
        already_processed, total
    );

    let file = OpenOptions::new()
        .write(true)
        .truncate(true)
        .open(path)
        .expect("Failed to open output file for rewriting");
    let mut writer = BufWriter::new(file);
    for line in &kept_lines {
        writeln!(writer, "{}", line).expect("Failed to write to output file");
    }
    writer.flush().expect("Failed to flush output file");

    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
}

fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Qa(qa_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                // Clamp so an oversized --offset empties the list instead of panicking.
                examples.drain(0..offset.min(examples.len()));
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::Repair(repair_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                // Clamp so an oversized --offset empties the list instead of panicking.
                examples.drain(0..offset.min(examples.len()));
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) =
                    repair::run_repair(&mut examples, repair_args, output.as_ref()).await
                {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                let markdown_output_dir = if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                    Some(dir.clone())
                } else {
                    None
                };

                // For --in-place, write to a temp file and rename at the end to avoid data loss on interruption
                let in_place_temp_path = if args.in_place {
                    output.as_ref().map(|path| {
                        let mut temp_path = path.clone();
                        temp_path.set_extension("jsonl.tmp");
                        temp_path
                    })
                } else {
                    None
                };

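                // A single background task owns the output file; workers send
                // fully serialized lines over an unbounded channel, so concurrent
                // tasks never interleave partial writes.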
                let output_sender: Option<mpsc::UnboundedSender<String>> = if !args.markdown
                    && (args.output.is_some() || !matches!(command, Command::Eval(_)))
                {
                    let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                    write_path.map(|path| {
                        let file = if args.in_place {
                            // For --in-place, write to temp file (truncate if exists)
                            OpenOptions::new()
                                .create(true)
                                .write(true)
                                .truncate(true)
                                .open(path)
                                .expect("Failed to open temp output file")
                        } else {
                            // For regular output, append to support resuming
                            OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file")
                        };
                        let mut writer = BufWriter::new(file);
                        let (sender, mut receiver) = mpsc::unbounded::<String>();
                        cx.background_spawn(async move {
                            while let Some(line) = receiver.next().await {
                                writeln!(writer, "{}", line).expect("Failed to write example");
                                writer.flush().expect("Failed to flush output");
                            }
                        })
                        .detach();
                        sender
                    })
                } else {
                    None
                };

                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
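                // Work is distributed per repository: each worker repeatedly pops
                // an entire repo's examples off the shared queue, so a given
                // worktree is only ever touched by one worker at a time.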
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::Qa(_)
                                        | Command::Repair(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref markdown_dir) = markdown_output_dir {
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

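                            // Before moving on to the next repo, shut down its
                            // language servers and evict the project from the
                            // caches so worktrees and file handles don't
                            // accumulate over the run.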
                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let (Some(temp_path), Some(final_path)) = (&in_place_temp_path, &output) {
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}

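/// Records a failure: bumps the failure counter, persists the example and error
/// under the run's failed directory and failed.jsonl (unless
/// --failed=skip-no-files), then either panics (--failfast, or when running a
/// single example) or logs the error and continues.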
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();

    let msg;
    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
        let example_name = example.spec.filename();

        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
        app_state
            .fs
            .write(
                &failed_example_path,
                &serde_json::to_vec_pretty(&example).unwrap(),
            )
            .await
            .unwrap();
        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
        app_state
            .fs
            .write(&err_path, format!("{error:?}").as_bytes())
            .await
            .unwrap();

        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
        let mut file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&failed_jsonl_path)
            .expect("Failed to open failed.jsonl");
        writeln!(file, "{}", serde_json::to_string(example).unwrap())
            .expect("Failed to write to failed.jsonl");

        let cursor_path = example
            .repo_name()
            .unwrap()
            .worktree_path()
            .join(&example.spec.cursor_path);
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m

                Example: \x1b[36m{}\x1b[0m
                Error file: \x1b[36m{}\x1b[0m
                Cursor file: \x1b[36m{}\x1b[0m
                Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
            "},
            example.spec.name,
            error,
            failed_example_path.display(),
            err_path.display(),
            cursor_path.display(),
            command,
            failed_example_path.display(),
        );
    } else {
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m
            "},
            example.spec.name, error
        );
    }

    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}