1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod pull_examples;
16mod qa;
17mod reorder_patch;
18mod repair;
19mod retrieve_context;
20mod reversal_tracking;
21mod score;
22mod split_commit;
23mod split_dataset;
24mod synthesize;
25mod word_diff;
26use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
27use collections::HashSet;
28use edit_prediction::EditPredictionStore;
29use futures::channel::mpsc;
30use futures::{SinkExt as _, StreamExt as _};
31use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
32use zeta_prompt::ZetaVersion;
33
34use reqwest_client::ReqwestClient;
35use serde::{Deserialize, Deserializer, Serialize, Serializer};
36use std::fmt::Display;
37use std::fs::{File, OpenOptions};
38use std::hash::{Hash, Hasher};
39use std::io::{BufRead, BufReader, BufWriter, Write};
40use std::sync::Mutex;
41use std::{path::PathBuf, sync::Arc};
42
43use crate::distill::run_distill;
44use crate::example::{Example, group_examples_by_repo, read_example_files};
45use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
46use crate::format_prompt::run_format_prompt;
47use crate::load_project::run_load_project;
48use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
49use crate::predict::run_prediction;
50use crate::progress::Progress;
51use crate::retrieve_context::run_context_retrieval;
52use crate::score::run_scoring;
53use crate::split_commit::SplitCommitArgs;
54use crate::split_dataset::SplitArgs;
55use crate::synthesize::{SynthesizeConfig, run_synthesize};
56
// Top-level CLI arguments. Most flags are `global = true` so they can appear
// after any subcommand; `///` doc comments below double as clap help text.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Dump the shell environment (via util::shell_env) and exit immediately.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of worker tasks that process example groups concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Cap on the total number of examples processed (applied after filtering
    // and resume logic; also bounds how many rows are pulled from Snowflake).
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of examples to drop from the front of the sorted example list.
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or Snowflake specifiers (see INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file (or directory when --markdown is set).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Write results back to the single input file; takes precedence over
    // --output (see `EpArgs::output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort the whole run on the first failed example.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
93
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// NOTE: the variant doc comments below double as clap help for `--failed`;
// `SkipNoFiles` additionally suppresses the failed-example files themselves
// (see `handle_error`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
106
// Long-form help attached to the positional `inputs` argument (see EpArgs).
// Keep the specifier list in sync with the parsers in `pull_examples` used by
// `load_examples`.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
164
// All `ep` subcommands. Variants with no payload run entirely off the global
// flags. NOTE: the `///` doc comments double as clap subcommand help text.
// The first group (Read..Eval) runs per-example inside the headless app loop
// in `main`; the rest are dispatched before the app starts and `unreachable!`
// inside the loop.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
}
204
205impl Display for Command {
206 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207 match self {
208 Command::Read => write!(f, "read"),
209 Command::LoadProject => write!(f, "load-project"),
210 Command::Context => write!(f, "context"),
211 Command::FormatPrompt(args) => {
212 write!(f, "format-prompt --provider={}", args.provider)
213 }
214 Command::Predict(args) => match &args.provider {
215 Some(provider) => write!(f, "predict --provider={}", provider),
216 None => write!(f, "predict"),
217 },
218 Command::ParseOutput => write!(f, "parse-output"),
219 Command::Score(args) => match &args.provider {
220 Some(provider) => write!(f, "score --provider={}", provider),
221 None => write!(f, "score"),
222 },
223 Command::Distill => write!(f, "distill"),
224 Command::Eval(args) => match &args.predict.provider {
225 Some(provider) => write!(f, "eval --provider={}", provider),
226 None => write!(f, "eval"),
227 },
228 Command::Synthesize(args) => {
229 write!(f, "synthesize --repos {}", args.repos.join(" "))
230 }
231 Command::Clean => write!(f, "clean"),
232 Command::SplitCommit(_) => write!(f, "split-commit"),
233 Command::Split(_) => write!(f, "split"),
234 Command::FilterLanguages(_) => write!(f, "filter-languages"),
235 Command::ImportBatch(args) => {
236 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
237 }
238 Command::Qa(_) => {
239 write!(f, "qa")
240 }
241 Command::Repair(_) => {
242 write!(f, "repair")
243 }
244 }
245 }
246}
247
// Arguments for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Unlike `predict`/`score`, the provider here is non-optional (with a
    // default), since the prompt format depends on it.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
253
// Arguments shared by `predict` and `score` (and embedded in `EvalArgs`).
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // NOTE(review): `None` is forwarded to predict::run_prediction /
    // run_scoring — presumably a default provider is chosen there; confirm.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Repetition count per example; interpreted by the predict module.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
264
// Arguments for `ep eval`: prediction settings plus an optional JSON summary
// destination (written by score::write_summary_json after the run).
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
273
/// Frontier model used as the "teacher" for the teacher prediction providers.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    // Anthropic Claude Sonnet 4.5 (API model "claude-sonnet-4-5").
    Sonnet45,
    // OpenAI GPT-5.2 (API model "gpt-5.2").
    Gpt52,
}
279
280impl std::fmt::Display for TeacherBackend {
281 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282 match self {
283 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
284 TeacherBackend::Gpt52 => write!(f, "gpt52"),
285 }
286 }
287}
288
289impl std::str::FromStr for TeacherBackend {
290 type Err = anyhow::Error;
291
292 fn from_str(s: &str) -> Result<Self, Self::Err> {
293 match s.to_lowercase().as_str() {
294 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
295 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
296 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
297 _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
298 }
299 }
300}
301
302impl TeacherBackend {
303 pub fn model_name(&self) -> &'static str {
304 match self {
305 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
306 TeacherBackend::Gpt52 => "gpt-5.2",
307 }
308 }
309}
310
/// Backend used to produce edit predictions. Parsed from strings like
/// `zeta2:<version>` or `teacher:<backend>` (see the `FromStr` impl), and
/// serialized via `Display` (see the `Serialize` impl).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // Zeta2 with an explicit prompt version.
    Zeta2(ZetaVersion),
    // Teacher model; the non-batching variant is presumably a direct API
    // path rather than the batch queue — semantics live in the predict module.
    Teacher(TeacherBackend),
    TeacherNonBatching(TeacherBackend),
    Repair,
}
321
322impl Default for PredictionProvider {
323 fn default() -> Self {
324 PredictionProvider::Zeta2(ZetaVersion::default())
325 }
326}
327
328impl std::fmt::Display for PredictionProvider {
329 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330 match self {
331 PredictionProvider::Sweep => write!(f, "sweep"),
332 PredictionProvider::Mercury => write!(f, "mercury"),
333 PredictionProvider::Zeta1 => write!(f, "zeta1"),
334 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
335 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
336 PredictionProvider::TeacherNonBatching(backend) => {
337 write!(f, "teacher-non-batching:{backend}")
338 }
339 PredictionProvider::Repair => write!(f, "repair"),
340 }
341 }
342}
343
344impl std::str::FromStr for PredictionProvider {
345 type Err = anyhow::Error;
346
347 fn from_str(s: &str) -> Result<Self, Self::Err> {
348 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
349
350 let provider_lower = provider.to_lowercase();
351 match provider_lower.as_str() {
352 "sweep" => Ok(PredictionProvider::Sweep),
353 "mercury" => Ok(PredictionProvider::Mercury),
354 "zeta1" => Ok(PredictionProvider::Zeta1),
355 "zeta2" => {
356 let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
357 Ok(PredictionProvider::Zeta2(version))
358 }
359 "teacher" => {
360 let backend = arg
361 .map(|a| a.parse())
362 .transpose()?
363 .unwrap_or(TeacherBackend::Sonnet45);
364 Ok(PredictionProvider::Teacher(backend))
365 }
366 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
367 let backend = arg
368 .map(|a| a.parse())
369 .transpose()?
370 .unwrap_or(TeacherBackend::Sonnet45);
371 Ok(PredictionProvider::TeacherNonBatching(backend))
372 }
373 "repair" => Ok(PredictionProvider::Repair),
374 _ => {
375 anyhow::bail!(
376 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
377 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
378 For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
379 Available zeta versions:\n{}",
380 ZetaVersion::options_as_string()
381 )
382 }
383 }
384 }
385}
386
387impl Serialize for PredictionProvider {
388 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
389 where
390 S: Serializer,
391 {
392 serializer.serialize_str(&self.to_string())
393 }
394}
395
396impl<'de> Deserialize<'de> for PredictionProvider {
397 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
398 where
399 D: Deserializer<'de>,
400 {
401 let s = String::deserialize(deserializer)?;
402 s.parse().map_err(serde::de::Error::custom)
403 }
404}
405
// CLI arguments for `ep synthesize`. The output directory comes from the
// global `-o/--output` flag (see the Synthesize arm in `main`), not from here.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
424
// CLI arguments for `ep import-batch`, which re-ingests provider batch
// results into the local LLM cache DB (see the ImportBatch arm in `main`).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
434
// Which LLM vendor's batch API an `import-batch` operation targets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
440
441impl EpArgs {
442 fn output_path(&self) -> Option<PathBuf> {
443 if self.in_place {
444 if self.inputs.len() == 1 {
445 self.inputs.first().cloned()
446 } else {
447 panic!("--in-place requires exactly one input file")
448 }
449 } else {
450 self.output.clone()
451 }
452 }
453}
454
455async fn load_examples(
456 http_client: Arc<dyn http_client::HttpClient>,
457 args: &EpArgs,
458 output_path: Option<&PathBuf>,
459 background_executor: BackgroundExecutor,
460) -> anyhow::Result<Vec<Example>> {
461 let mut captured_after_timestamps = Vec::new();
462 let mut rejected_after_timestamps = Vec::new();
463 let mut requested_after_timestamps = Vec::new();
464 let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
465 Vec::new();
466 let mut file_inputs = Vec::new();
467
468 for input in &args.inputs {
469 let input_string = input.to_string_lossy();
470 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
471 captured_after_timestamps.push(timestamp.to_string());
472 } else if let Some(timestamp) =
473 pull_examples::parse_rejected_after_input(input_string.as_ref())
474 {
475 rejected_after_timestamps.push(timestamp.to_string());
476 } else if let Some(timestamp) =
477 pull_examples::parse_requested_after_input(input_string.as_ref())
478 {
479 requested_after_timestamps.push(timestamp.to_string());
480 } else if let Some((timestamp, rating_filter)) =
481 pull_examples::parse_rated_after_input(input_string.as_ref())
482 {
483 rated_after_inputs.push((timestamp.to_string(), rating_filter));
484 } else {
485 file_inputs.push(input.clone());
486 }
487 }
488
489 let mut examples = read_example_files(&file_inputs);
490
491 Progress::global().set_total_examples(examples.len());
492
493 let remaining_limit_for_snowflake =
494 args.limit.map(|limit| limit.saturating_sub(examples.len()));
495
496 if let Some(0) = remaining_limit_for_snowflake {
497 log::info!(
498 "skipping Snowflake inputs because --limit is already satisfied by example files"
499 );
500 } else {
501 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
502
503 if !captured_after_timestamps.is_empty() {
504 captured_after_timestamps.sort();
505
506 let mut captured_examples = pull_examples::fetch_captured_examples_after(
507 http_client.clone(),
508 &captured_after_timestamps,
509 max_rows_per_timestamp,
510 background_executor.clone(),
511 )
512 .await?;
513 examples.append(&mut captured_examples);
514 }
515
516 if !rejected_after_timestamps.is_empty() {
517 rejected_after_timestamps.sort();
518
519 let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
520 http_client.clone(),
521 &rejected_after_timestamps,
522 max_rows_per_timestamp,
523 background_executor.clone(),
524 )
525 .await?;
526 examples.append(&mut rejected_examples);
527 }
528
529 if !requested_after_timestamps.is_empty() {
530 requested_after_timestamps.sort();
531
532 let mut requested_examples = pull_examples::fetch_requested_examples_after(
533 http_client.clone(),
534 &requested_after_timestamps,
535 max_rows_per_timestamp,
536 background_executor.clone(),
537 )
538 .await?;
539 examples.append(&mut requested_examples);
540 }
541
542 if !rated_after_inputs.is_empty() {
543 rated_after_inputs.sort();
544
545 let mut rated_examples = pull_examples::fetch_rated_examples_after(
546 http_client,
547 &rated_after_inputs,
548 max_rows_per_timestamp,
549 background_executor,
550 )
551 .await?;
552 examples.append(&mut rated_examples);
553 }
554 }
555
556 crate::example::sort_examples_by_repo_and_rev(&mut examples);
557
558 if let Some(name_filter) = &args.name {
559 examples.retain(|example| example.spec.name.contains(name_filter));
560 }
561 if let Some(repo_filter) = &args.repo {
562 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
563 }
564
565 // Skip resume logic for --in-place since input and output are the same file,
566 // which would incorrectly treat all input examples as already processed.
567 if !args.in_place {
568 if let Some(path) = output_path {
569 resume_from_output(path, &mut examples);
570 }
571 }
572
573 if let Some(offset) = args.offset {
574 examples.splice(0..offset, []);
575 }
576
577 if let Some(limit) = args.limit {
578 examples.truncate(limit);
579 }
580
581 let progress = Progress::global();
582 progress.set_total_examples(examples.len());
583 progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
584
585 Ok(examples)
586}
587
/// Stable fingerprint of an example spec, used to match pending input
/// examples against lines already present in the output file when resuming
/// (see `resume_from_output`). FxHasher is not cryptographic — collisions
/// are merely unlikely, not impossible.
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
593
594fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
595 let file = match File::open(path) {
596 Ok(f) => f,
597 Err(_) => return,
598 };
599
600 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
601
602 let reader = BufReader::new(file);
603 let mut kept_lines = Vec::new();
604 let mut kept_hashes = HashSet::default();
605
606 for line in reader.lines() {
607 let line = match line {
608 Ok(l) => l,
609 Err(_) => continue,
610 };
611
612 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
613 let hash = spec_hash(&output_example.spec);
614 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
615 kept_hashes.insert(hash);
616 kept_lines.push(line);
617 }
618 }
619 }
620
621 let total = examples.len();
622 let already_processed = kept_hashes.len();
623
624 eprintln!(
625 "Resuming: {}/{} examples already processed",
626 already_processed, total
627 );
628
629 let file = OpenOptions::new()
630 .write(true)
631 .truncate(true)
632 .open(path)
633 .expect("Failed to open output file for rewriting");
634 let mut writer = BufWriter::new(file);
635 for line in &kept_lines {
636 writeln!(writer, "{}", line).expect("Failed to write to output file");
637 }
638 writer.flush().expect("Failed to flush output file");
639
640 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
641}
642
/// Entry point. Self-contained subcommands (import-batch, clean, synthesize,
/// split-commit, split, filter-languages, qa, repair) are dispatched directly
/// and return early; all per-example commands run inside a headless GPUI
/// application with a pool of worker tasks.
fn main() {
    let args = EpArgs::parse();

    // `--printenv` short-circuits everything else.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    // Resolved output destination (honors --in-place; see EpArgs::output_path).
    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // No subcommand: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that don't need the headless app are handled here and return;
    // the remaining commands fall through to the app loop below.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Removes all cached repos/worktrees under the data directory.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesize writes one example file per result into the -o dir.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Qa(qa_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            // NOTE(review): this duplicates the name/repo/offset/limit logic
            // used for the app-loop commands; an --offset larger than the
            // example count panics here (splice range out of bounds).
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::Repair(repair_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters (same duplicated logic as the Qa arm above).
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) =
                    repair::run_repair(&mut examples, repair_args, output.as_ref()).await
                {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

    // Per-example commands run inside a headless GPUI application so they can
    // use real Project/LSP infrastructure.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync provider batch state before the run (semantics in the
                // predict module — presumably pulls completed batch results).
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // A single-example run fails fast even without --failfast.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // --in-place writes to a temp file that is atomically
                    // renamed over the input once the run completes.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // A single background task owns the writer; workers send
                    // serialized lines over an unbounded channel.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Examples are grouped by repository; each worker processes a
                // whole repo group at a time so project state can be torn down
                // once per repo.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            // Claim the next repo group, or stop this worker.
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Run the stage matching the active subcommand.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        // Already dispatched before the app
                                        // started (see the match in `main`).
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::Qa(_)
                                        | Command::Repair(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are still emitted when
                                // --failed=keep (the default).
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o: stream JSONL to stdout (eval
                                        // prints its report instead).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Done with this repo: shut down its language
                            // servers and drop cached project state so memory
                            // doesn't grow with the number of repos processed.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Sync batch state again now that the run has queued requests.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // Eval prints its aggregate report (and optional JSON summary)
                // from the finished set.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1096
/// Records a failed example, then either aborts the run (--failfast, or a
/// single-example run) or logs the error and lets the worker continue.
///
/// Unless `--failed=skip-no-files`, three artifacts are written: the full
/// example JSON and a `{name}_err.txt` under the failed-examples directory,
/// plus an appended line in the run's `failed.jsonl`.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();

    let msg;
    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
        let example_name = example.spec.filename();

        // Full example JSON — usable directly as input for a re-run.
        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
        app_state
            .fs
            .write(
                &failed_example_path,
                &serde_json::to_vec_pretty(&example).unwrap(),
            )
            .await
            .unwrap();
        // Debug-formatted error (includes the anyhow context chain).
        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
        app_state
            .fs
            .write(&err_path, format!("{error:?}").as_bytes())
            .await
            .unwrap();

        // Aggregate JSONL of every failure in this run.
        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
        let mut file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&failed_jsonl_path)
            .expect("Failed to open failed.jsonl");
        writeln!(file, "{}", serde_json::to_string(example).unwrap())
            .expect("Failed to write to failed.jsonl");

        // Best-effort path to the cursor file inside the example's worktree;
        // falls back to the raw cursor_path when the repo name is unavailable.
        let cursor_path = match example.repo_name() {
            Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
            Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
        };
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m

                Example: \x1b[36m{}\x1b[0m
                Error file: \x1b[36m{}\x1b[0m
                Cursor file: \x1b[36m{}\x1b[0m
                Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
            "},
            example.spec.name,
            error,
            failed_example_path.display(),
            err_path.display(),
            cursor_path.display(),
            command,
            failed_example_path.display(),
        );
    } else {
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m
            "},
            example.spec.name, error
        );
    }

    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}