1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25mod synthesize;
26mod word_diff;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use zeta_prompt::ZetaVersion;
43
44use crate::distill::run_distill;
45use crate::example::{Example, group_examples_by_repo, read_example_files};
46use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
47use crate::format_prompt::run_format_prompt;
48use crate::load_project::run_load_project;
49use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
50use crate::predict::run_prediction;
51use crate::progress::Progress;
52use crate::retrieve_context::run_context_retrieval;
53use crate::score::run_scoring;
54use crate::split_commit::SplitCommitArgs;
55use crate::split_dataset::SplitArgs;
56use crate::synthesize::{SynthesizeConfig, run_synthesize};
57
// Top-level CLI arguments for the `ep` tool. Most flags are `global`, so they
// may appear before or after the subcommand.
// NOTE: clap turns `///` doc comments on this struct and its fields into help
// text, so reviewer documentation is kept as `//` comments to leave the
// generated `--help` output unchanged.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the captured shell environment and exit (handled first in main).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of concurrent workers processing per-repository example groups.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Cap on the number of examples processed; also caps Snowflake fetches.
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to skip (applied after filtering/resume).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or Snowflake specifiers (see INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file (JSONL) or directory (with --markdown).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (see EpArgs::output_path).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Panic on the first failed example instead of logging and continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
94
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// NOTE: the `///` variant docs below double as clap's `--failed` help text;
// reviewer notes are `//` comments so the help output is unchanged.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // Unlike `Skip`, this also suppresses the per-example failure artifacts
    // (failed example JSON, error text, failed.jsonl) written by handle_error.
    SkipNoFiles,
}
107
// Long-form help for the positional `inputs` argument, attached via
// `help = INPUTS_HELP` on `EpArgs::inputs`. This is a runtime string shown to
// users by clap — its content and indentation are user-visible, edit with care.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
165
// Subcommands of `ep`. The `///` doc comments feed clap's generated help, so
// reviewer notes are kept as `//` comments. The `Display` impl below renders
// these as re-run hints and must stay in sync when variants are added.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
}
205
206impl Display for Command {
207 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208 match self {
209 Command::Read => write!(f, "read"),
210 Command::LoadProject => write!(f, "load-project"),
211 Command::Context => write!(f, "context"),
212 Command::FormatPrompt(args) => {
213 write!(f, "format-prompt --provider={}", args.provider)
214 }
215 Command::Predict(args) => match &args.provider {
216 Some(provider) => write!(f, "predict --provider={}", provider),
217 None => write!(f, "predict"),
218 },
219 Command::ParseOutput => write!(f, "parse-output"),
220 Command::Score(args) => match &args.provider {
221 Some(provider) => write!(f, "score --provider={}", provider),
222 None => write!(f, "score"),
223 },
224 Command::Distill => write!(f, "distill"),
225 Command::Eval(args) => match &args.predict.provider {
226 Some(provider) => write!(f, "eval --provider={}", provider),
227 None => write!(f, "eval"),
228 },
229 Command::Synthesize(args) => {
230 write!(f, "synthesize --repos {}", args.repos.join(" "))
231 }
232 Command::Clean => write!(f, "clean"),
233 Command::SplitCommit(_) => write!(f, "split-commit"),
234 Command::Split(_) => write!(f, "split"),
235 Command::FilterLanguages(_) => write!(f, "filter-languages"),
236 Command::ImportBatch(args) => {
237 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
238 }
239 Command::Qa(_) => {
240 write!(f, "qa")
241 }
242 Command::Repair(_) => {
243 write!(f, "repair")
244 }
245 }
246 }
247}
248
// Arguments for `ep format-prompt`. (`//` comments only — `///` would alter
// clap's generated help.)
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to emit; defaults to zeta2 at the default
    // ZetaVersion (see `PredictionProvider::default`).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
254
// Arguments shared by `ep predict` and `ep score` (and, flattened, `ep eval`).
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider override; when omitted, selection is left to the
    // prediction pipeline. NOTE(review): exact fallback behavior lives in
    // predict::run_prediction — not visible here.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // How many predictions to request per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
265
// Arguments for `ep eval`: the full prediction option set plus an optional
// JSON summary destination consumed after scoring (see main's Eval handling).
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
274
/// Teacher model used by the `teacher` / `teacher-non-batching` providers.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Anthropic Claude Sonnet 4.5 (`claude-sonnet-4-5`, see `model_name`).
    Sonnet45,
    /// OpenAI GPT-5.2 (`gpt-5.2`, see `model_name`).
    Gpt52,
}
280
281impl std::fmt::Display for TeacherBackend {
282 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283 match self {
284 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
285 TeacherBackend::Gpt52 => write!(f, "gpt52"),
286 }
287 }
288}
289
290impl std::str::FromStr for TeacherBackend {
291 type Err = anyhow::Error;
292
293 fn from_str(s: &str) -> Result<Self, Self::Err> {
294 match s.to_lowercase().as_str() {
295 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
296 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
297 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
298 _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
299 }
300 }
301}
302
303impl TeacherBackend {
304 pub fn model_name(&self) -> &'static str {
305 match self {
306 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
307 TeacherBackend::Gpt52 => "gpt-5.2",
308 }
309 }
310}
311
/// Which backend produces edit predictions. Serialized/parsed as a string of
/// the form `provider` or `provider:arg` (see the `Display`/`FromStr` impls).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // Zeta2 carries a prompt-format version (`zeta2:<version>`).
    Zeta2(ZetaVersion),
    // Teacher model used via the batching API path.
    Teacher(TeacherBackend),
    // Teacher model used with direct (non-batched) requests.
    TeacherNonBatching(TeacherBackend),
    Repair,
}
322
323impl Default for PredictionProvider {
324 fn default() -> Self {
325 PredictionProvider::Zeta2(ZetaVersion::default())
326 }
327}
328
329impl std::fmt::Display for PredictionProvider {
330 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331 match self {
332 PredictionProvider::Sweep => write!(f, "sweep"),
333 PredictionProvider::Mercury => write!(f, "mercury"),
334 PredictionProvider::Zeta1 => write!(f, "zeta1"),
335 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
336 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
337 PredictionProvider::TeacherNonBatching(backend) => {
338 write!(f, "teacher-non-batching:{backend}")
339 }
340 PredictionProvider::Repair => write!(f, "repair"),
341 }
342 }
343}
344
345impl std::str::FromStr for PredictionProvider {
346 type Err = anyhow::Error;
347
348 fn from_str(s: &str) -> Result<Self, Self::Err> {
349 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
350
351 let provider_lower = provider.to_lowercase();
352 match provider_lower.as_str() {
353 "sweep" => Ok(PredictionProvider::Sweep),
354 "mercury" => Ok(PredictionProvider::Mercury),
355 "zeta1" => Ok(PredictionProvider::Zeta1),
356 "zeta2" => {
357 let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
358 Ok(PredictionProvider::Zeta2(version))
359 }
360 "teacher" => {
361 let backend = arg
362 .map(|a| a.parse())
363 .transpose()?
364 .unwrap_or(TeacherBackend::Sonnet45);
365 Ok(PredictionProvider::Teacher(backend))
366 }
367 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
368 let backend = arg
369 .map(|a| a.parse())
370 .transpose()?
371 .unwrap_or(TeacherBackend::Sonnet45);
372 Ok(PredictionProvider::TeacherNonBatching(backend))
373 }
374 "repair" => Ok(PredictionProvider::Repair),
375 _ => {
376 anyhow::bail!(
377 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
378 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
379 For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
380 Available zeta versions:\n{}",
381 ZetaVersion::options_as_string()
382 )
383 }
384 }
385 }
386}
387
388impl Serialize for PredictionProvider {
389 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
390 where
391 S: Serializer,
392 {
393 serializer.serialize_str(&self.to_string())
394 }
395}
396
397impl<'de> Deserialize<'de> for PredictionProvider {
398 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
399 where
400 D: Deserializer<'de>,
401 {
402 let s = String::deserialize(deserializer)?;
403 s.parse().map_err(serde::de::Error::custom)
404 }
405}
406
// Arguments for `ep synthesize` (converted into a SynthesizeConfig in main;
// the output directory comes from the global `-o` flag).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
425
// Arguments for `ep import-batch`, which re-imports completed LLM batch
// results into the local cache (see the ImportBatch arm in main).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
435
// Which vendor's batch API `import-batch` targets. ValueEnum derives the clap
// values from the variant names; `//` comments keep the help text unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
441
442impl EpArgs {
443 fn output_path(&self) -> Option<PathBuf> {
444 if self.in_place {
445 if self.inputs.len() == 1 {
446 self.inputs.first().cloned()
447 } else {
448 panic!("--in-place requires exactly one input file")
449 }
450 } else {
451 self.output.clone()
452 }
453 }
454}
455
456async fn load_examples(
457 http_client: Arc<dyn http_client::HttpClient>,
458 args: &EpArgs,
459 output_path: Option<&PathBuf>,
460 background_executor: BackgroundExecutor,
461) -> anyhow::Result<Vec<Example>> {
462 let mut captured_after_timestamps = Vec::new();
463 let mut rejected_after_timestamps = Vec::new();
464 let mut requested_after_timestamps = Vec::new();
465 let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
466 Vec::new();
467 let mut file_inputs = Vec::new();
468
469 for input in &args.inputs {
470 let input_string = input.to_string_lossy();
471 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
472 captured_after_timestamps.push(timestamp.to_string());
473 } else if let Some(timestamp) =
474 pull_examples::parse_rejected_after_input(input_string.as_ref())
475 {
476 rejected_after_timestamps.push(timestamp.to_string());
477 } else if let Some(timestamp) =
478 pull_examples::parse_requested_after_input(input_string.as_ref())
479 {
480 requested_after_timestamps.push(timestamp.to_string());
481 } else if let Some((timestamp, rating_filter)) =
482 pull_examples::parse_rated_after_input(input_string.as_ref())
483 {
484 rated_after_inputs.push((timestamp.to_string(), rating_filter));
485 } else {
486 file_inputs.push(input.clone());
487 }
488 }
489
490 let mut examples = read_example_files(&file_inputs);
491
492 Progress::global().set_total_examples(examples.len());
493
494 let remaining_limit_for_snowflake =
495 args.limit.map(|limit| limit.saturating_sub(examples.len()));
496
497 if let Some(0) = remaining_limit_for_snowflake {
498 log::info!(
499 "skipping Snowflake inputs because --limit is already satisfied by example files"
500 );
501 } else {
502 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
503
504 if !captured_after_timestamps.is_empty() {
505 captured_after_timestamps.sort();
506
507 let mut captured_examples = pull_examples::fetch_captured_examples_after(
508 http_client.clone(),
509 &captured_after_timestamps,
510 max_rows_per_timestamp,
511 background_executor.clone(),
512 )
513 .await?;
514 examples.append(&mut captured_examples);
515 }
516
517 if !rejected_after_timestamps.is_empty() {
518 rejected_after_timestamps.sort();
519
520 let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
521 http_client.clone(),
522 &rejected_after_timestamps,
523 max_rows_per_timestamp,
524 background_executor.clone(),
525 )
526 .await?;
527 examples.append(&mut rejected_examples);
528 }
529
530 if !requested_after_timestamps.is_empty() {
531 requested_after_timestamps.sort();
532
533 let mut requested_examples = pull_examples::fetch_requested_examples_after(
534 http_client.clone(),
535 &requested_after_timestamps,
536 max_rows_per_timestamp,
537 background_executor.clone(),
538 )
539 .await?;
540 examples.append(&mut requested_examples);
541 }
542
543 if !rated_after_inputs.is_empty() {
544 rated_after_inputs.sort();
545
546 let mut rated_examples = pull_examples::fetch_rated_examples_after(
547 http_client,
548 &rated_after_inputs,
549 max_rows_per_timestamp,
550 background_executor,
551 )
552 .await?;
553 examples.append(&mut rated_examples);
554 }
555 }
556
557 crate::example::sort_examples_by_repo_and_rev(&mut examples);
558
559 if let Some(name_filter) = &args.name {
560 examples.retain(|example| example.spec.name.contains(name_filter));
561 }
562 if let Some(repo_filter) = &args.repo {
563 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
564 }
565
566 // Skip resume logic for --in-place since input and output are the same file,
567 // which would incorrectly treat all input examples as already processed.
568 if !args.in_place {
569 if let Some(path) = output_path {
570 resume_from_output(path, &mut examples);
571 }
572 }
573
574 if let Some(offset) = args.offset {
575 examples.splice(0..offset, []);
576 }
577
578 if let Some(limit) = args.limit {
579 examples.truncate(limit);
580 }
581
582 let progress = Progress::global();
583 progress.set_total_examples(examples.len());
584 progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
585
586 Ok(examples)
587}
588
// Fingerprint of an example spec, used as the de-duplication / resume key by
// `resume_from_output`.
// NOTE(review): FxHasher output is not guaranteed stable across toolchain or
// crate versions — acceptable here because hashes are only compared against
// data produced within compatible runs; confirm if persisted longer-term.
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
594
595fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
596 let file = match File::open(path) {
597 Ok(f) => f,
598 Err(_) => return,
599 };
600
601 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
602
603 let reader = BufReader::new(file);
604 let mut kept_lines = Vec::new();
605 let mut kept_hashes = HashSet::default();
606
607 for line in reader.lines() {
608 let line = match line {
609 Ok(l) => l,
610 Err(_) => continue,
611 };
612
613 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
614 let hash = spec_hash(&output_example.spec);
615 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
616 kept_hashes.insert(hash);
617 kept_lines.push(line);
618 }
619 }
620 }
621
622 let total = examples.len();
623 let already_processed = kept_hashes.len();
624
625 eprintln!(
626 "Resuming: {}/{} examples already processed",
627 already_processed, total
628 );
629
630 let file = OpenOptions::new()
631 .write(true)
632 .truncate(true)
633 .open(path)
634 .expect("Failed to open output file for rewriting");
635 let mut writer = BufWriter::new(file);
636 for line in &kept_lines {
637 writeln!(writer, "{}", line).expect("Failed to write to output file");
638 }
639 writer.flush().expect("Failed to flush output file");
640
641 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
642}
643
/// CLI entry point.
///
/// Self-contained subcommands (`import-batch`, `clean`, `synthesize`,
/// `split-commit`, `split`, `filter-languages`, `qa`, `repair`) are handled
/// synchronously below and return early. All remaining pipeline commands run
/// inside a headless GPUI application: examples are grouped by repository and
/// processed by up to `--max-parallelism` concurrent workers.
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    // Resolves -o / --in-place into a single output target (panics when
    // --in-place is used without exactly one input).
    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // No subcommand: print help and exit successfully.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that need neither the headless app nor project worktrees are
    // dispatched here and return early.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Removes all cached repos/worktrees under the data directory.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Qa(qa_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::Repair(repair_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) =
                    repair::run_repair(&mut examples, repair_args, output.as_ref()).await
                {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

    // Pipeline commands run inside a headless GPUI app so projects, language
    // servers, and the edit-prediction store are available.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Pull down any completed LLM batch results before processing.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // A single example gets failfast behavior so its error surfaces.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // --in-place writes to a temp file that is renamed over the
                    // original only after the run completes successfully.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // A single background task owns the writer; workers send it
                    // serialized lines over an unbounded channel.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Work queue keyed by repository: each worker takes a whole
                // repo group so a repo's examples are processed sequentially.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::Qa(_)
                                        | Command::Repair(_) => {
                                            // These returned early in the dispatch above.
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples still reach the main output
                                // unless --failed=skip / skip-no-files.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o: stream JSONL to stdout (except
                                        // eval, which prints a report instead).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Repo group finished: shut down its language
                            // servers and drop cached project state.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Flush batch state accumulated during the run.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1097
1098async fn handle_error(
1099 error: anyhow::Error,
1100 args: &EpArgs,
1101 command: &Command,
1102 app_state: &Arc<headless::EpAppState>,
1103 failfast_on_single_example: bool,
1104 example: &Example,
1105) {
1106 Progress::global().increment_failed();
1107
1108 let msg;
1109 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1110 let example_name = example.spec.filename();
1111
1112 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1113 app_state
1114 .fs
1115 .write(
1116 &failed_example_path,
1117 &serde_json::to_vec_pretty(&example).unwrap(),
1118 )
1119 .await
1120 .unwrap();
1121 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1122 app_state
1123 .fs
1124 .write(&err_path, format!("{error:?}").as_bytes())
1125 .await
1126 .unwrap();
1127
1128 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1129 let mut file = OpenOptions::new()
1130 .create(true)
1131 .append(true)
1132 .open(&failed_jsonl_path)
1133 .expect("Failed to open failed.jsonl");
1134 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1135 .expect("Failed to write to failed.jsonl");
1136
1137 let cursor_path = match example.repo_name() {
1138 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1139 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1140 };
1141 msg = format!(
1142 indoc::indoc! {"
1143 While processing \"{}\":
1144
1145 \x1b[31m{:?}\x1b[0m
1146
1147 Example: \x1b[36m{}\x1b[0m
1148 Error file: \x1b[36m{}\x1b[0m
1149 Cursor file: \x1b[36m{}\x1b[0m
1150 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1151 "},
1152 example.spec.name,
1153 error,
1154 failed_example_path.display(),
1155 err_path.display(),
1156 cursor_path.display(),
1157 command,
1158 failed_example_path.display(),
1159 );
1160 } else {
1161 msg = format!(
1162 indoc::indoc! {"
1163 While processing \"{}\":
1164
1165 \x1b[31m{:?}\x1b[0m
1166 "},
1167 example.spec.name, error
1168 );
1169 }
1170
1171 if args.failfast || failfast_on_single_example {
1172 Progress::global().finalize();
1173 panic!("{}", msg);
1174 } else {
1175 log::error!("{}", msg);
1176 }
1177}