1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod pull_examples;
16mod qa;
17mod reorder_patch;
18mod repair;
19mod retrieve_context;
20mod reversal_tracking;
21mod score;
22mod split_commit;
23mod split_dataset;
24mod synthesize;
25mod word_diff;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use zeta_prompt::ZetaVersion;

use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};

use crate::distill::run_distill;
use crate::example::{Example, group_examples_by_repo, read_example_files};
use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
use crate::split_dataset::SplitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};
56
// Top-level CLI arguments for the `ep` binary. Flags marked `global = true`
// apply to every subcommand; positional `inputs` accept file paths or
// Snowflake specifiers (see INPUTS_HELP). Comments here use `//` rather than
// `///` so clap's generated help text is unchanged.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit (checked first in `main`).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of worker tasks processing example groups concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Maximum number of examples to process (applied after filtering/resume).
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to skip, applied after filtering.
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file, or output directory when --markdown is set.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (see `EpArgs::output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort on the first failed example instead of continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
93
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // Also suppresses the per-example failure artifacts written by `handle_error`
    // (failed example JSON, error text, and the failed.jsonl entry).
    SkipNoFiles,
}
106
// Long-form help for the positional `inputs` argument of `EpArgs`
// (attached via `#[clap(help = INPUTS_HELP)]`). This is user-facing runtime
// text: edit with care.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
164
// All `ep` subcommands. The `///` doc comments double as clap help text, so
// they are left untouched; the Display impl below renders each variant as the
// CLI string used in "Re-run" hints.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
}
204
205impl Display for Command {
206 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207 match self {
208 Command::Read => write!(f, "read"),
209 Command::LoadProject => write!(f, "load-project"),
210 Command::Context => write!(f, "context"),
211 Command::FormatPrompt(args) => {
212 write!(f, "format-prompt --provider={}", args.provider)
213 }
214 Command::Predict(args) => match &args.provider {
215 Some(provider) => write!(f, "predict --provider={}", provider),
216 None => write!(f, "predict"),
217 },
218 Command::ParseOutput => write!(f, "parse-output"),
219 Command::Score(args) => match &args.provider {
220 Some(provider) => write!(f, "score --provider={}", provider),
221 None => write!(f, "score"),
222 },
223 Command::Distill => write!(f, "distill"),
224 Command::Eval(args) => match &args.predict.provider {
225 Some(provider) => write!(f, "eval --provider={}", provider),
226 None => write!(f, "eval"),
227 },
228 Command::Synthesize(args) => {
229 write!(f, "synthesize --repos {}", args.repos.join(" "))
230 }
231 Command::Clean => write!(f, "clean"),
232 Command::SplitCommit(_) => write!(f, "split-commit"),
233 Command::Split(_) => write!(f, "split"),
234 Command::FilterLanguages(_) => write!(f, "filter-languages"),
235 Command::ImportBatch(args) => {
236 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
237 }
238 Command::Qa(_) => {
239 write!(f, "qa")
240 }
241 Command::Repair(_) => {
242 write!(f, "repair")
243 }
244 }
245 }
246}
247
// Arguments for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to generate; defaults to the default
    // zeta2 version (see `PredictionProvider::default`).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
253
// Arguments shared by `ep predict`, `ep score`, and (via `EvalArgs`) `ep eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Provider to run predictions with; `None` means "use whatever the
    // example/prompt already specifies" (see the predict/score pipelines).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Number of times to repeat each prediction.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
261
// Arguments for `ep eval`: the predict/score options plus an optional
// JSON summary destination.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
270
/// Teacher model backing the `teacher` / `teacher-non-batching` providers.
/// `model_name` maps each variant to the upstream model identifier.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Anthropic Claude Sonnet 4.5 ("claude-sonnet-4-5").
    Sonnet45,
    /// OpenAI GPT-5.2 ("gpt-5.2").
    Gpt52,
}
276
277impl std::fmt::Display for TeacherBackend {
278 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
279 match self {
280 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
281 TeacherBackend::Gpt52 => write!(f, "gpt52"),
282 }
283 }
284}
285
286impl std::str::FromStr for TeacherBackend {
287 type Err = anyhow::Error;
288
289 fn from_str(s: &str) -> Result<Self, Self::Err> {
290 match s.to_lowercase().as_str() {
291 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
292 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
293 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
294 _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
295 }
296 }
297}
298
299impl TeacherBackend {
300 pub fn model_name(&self) -> &'static str {
301 match self {
302 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
303 TeacherBackend::Gpt52 => "gpt-5.2",
304 }
305 }
306}
307
/// Prediction backend selected on the command line. Parsed from strings like
/// `zeta2:<version>` or `teacher:<backend>` (see the `FromStr` impl) and
/// serialized back via `Display`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // Carries the zeta prompt-format version to use.
    Zeta2(ZetaVersion),
    // Teacher variants carry the backing model. NOTE(review): presumably the
    // non-batching variant bypasses the provider's batch API — confirm in predict.rs.
    Teacher(TeacherBackend),
    TeacherNonBatching(TeacherBackend),
    Repair,
}
318
// Default provider: zeta2 at the default prompt version.
impl Default for PredictionProvider {
    fn default() -> Self {
        PredictionProvider::Zeta2(ZetaVersion::default())
    }
}
324
325impl std::fmt::Display for PredictionProvider {
326 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327 match self {
328 PredictionProvider::Sweep => write!(f, "sweep"),
329 PredictionProvider::Mercury => write!(f, "mercury"),
330 PredictionProvider::Zeta1 => write!(f, "zeta1"),
331 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
332 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
333 PredictionProvider::TeacherNonBatching(backend) => {
334 write!(f, "teacher-non-batching:{backend}")
335 }
336 PredictionProvider::Repair => write!(f, "repair"),
337 }
338 }
339}
340
341impl std::str::FromStr for PredictionProvider {
342 type Err = anyhow::Error;
343
344 fn from_str(s: &str) -> Result<Self, Self::Err> {
345 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
346
347 let provider_lower = provider.to_lowercase();
348 match provider_lower.as_str() {
349 "sweep" => Ok(PredictionProvider::Sweep),
350 "mercury" => Ok(PredictionProvider::Mercury),
351 "zeta1" => Ok(PredictionProvider::Zeta1),
352 "zeta2" => {
353 let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
354 Ok(PredictionProvider::Zeta2(version))
355 }
356 "teacher" => {
357 let backend = arg
358 .map(|a| a.parse())
359 .transpose()?
360 .unwrap_or(TeacherBackend::Sonnet45);
361 Ok(PredictionProvider::Teacher(backend))
362 }
363 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
364 let backend = arg
365 .map(|a| a.parse())
366 .transpose()?
367 .unwrap_or(TeacherBackend::Sonnet45);
368 Ok(PredictionProvider::TeacherNonBatching(backend))
369 }
370 "repair" => Ok(PredictionProvider::Repair),
371 _ => {
372 anyhow::bail!(
373 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
374 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
375 For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
376 Available zeta versions:\n{}",
377 ZetaVersion::options_as_string()
378 )
379 }
380 }
381 }
382}
383
384impl Serialize for PredictionProvider {
385 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
386 where
387 S: Serializer,
388 {
389 serializer.serialize_str(&self.to_string())
390 }
391}
392
393impl<'de> Deserialize<'de> for PredictionProvider {
394 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
395 where
396 D: Deserializer<'de>,
397 {
398 let s = String::deserialize(deserializer)?;
399 s.parse().map_err(serde::de::Error::custom)
400 }
401}
402
// Arguments for `ep synthesize`; converted into a `SynthesizeConfig` in `main`.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
421
// Arguments for `ep import-batch`.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
431
// Provider whose batch API `ep import-batch` should talk to
// (clap value enum: `anthropic` / `openai`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
437
438impl EpArgs {
439 fn output_path(&self) -> Option<PathBuf> {
440 if self.in_place {
441 if self.inputs.len() == 1 {
442 self.inputs.first().cloned()
443 } else {
444 panic!("--in-place requires exactly one input file")
445 }
446 } else {
447 self.output.clone()
448 }
449 }
450}
451
452async fn load_examples(
453 http_client: Arc<dyn http_client::HttpClient>,
454 args: &EpArgs,
455 output_path: Option<&PathBuf>,
456 background_executor: BackgroundExecutor,
457) -> anyhow::Result<Vec<Example>> {
458 let mut captured_after_timestamps = Vec::new();
459 let mut rejected_after_timestamps = Vec::new();
460 let mut requested_after_timestamps = Vec::new();
461 let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
462 Vec::new();
463 let mut file_inputs = Vec::new();
464
465 for input in &args.inputs {
466 let input_string = input.to_string_lossy();
467 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
468 captured_after_timestamps.push(timestamp.to_string());
469 } else if let Some(timestamp) =
470 pull_examples::parse_rejected_after_input(input_string.as_ref())
471 {
472 rejected_after_timestamps.push(timestamp.to_string());
473 } else if let Some(timestamp) =
474 pull_examples::parse_requested_after_input(input_string.as_ref())
475 {
476 requested_after_timestamps.push(timestamp.to_string());
477 } else if let Some((timestamp, rating_filter)) =
478 pull_examples::parse_rated_after_input(input_string.as_ref())
479 {
480 rated_after_inputs.push((timestamp.to_string(), rating_filter));
481 } else {
482 file_inputs.push(input.clone());
483 }
484 }
485
486 let mut examples = read_example_files(&file_inputs);
487
488 Progress::global().set_total_examples(examples.len());
489
490 let remaining_limit_for_snowflake =
491 args.limit.map(|limit| limit.saturating_sub(examples.len()));
492
493 if let Some(0) = remaining_limit_for_snowflake {
494 log::info!(
495 "skipping Snowflake inputs because --limit is already satisfied by example files"
496 );
497 } else {
498 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
499
500 if !captured_after_timestamps.is_empty() {
501 captured_after_timestamps.sort();
502
503 let mut captured_examples = pull_examples::fetch_captured_examples_after(
504 http_client.clone(),
505 &captured_after_timestamps,
506 max_rows_per_timestamp,
507 background_executor.clone(),
508 )
509 .await?;
510 examples.append(&mut captured_examples);
511 }
512
513 if !rejected_after_timestamps.is_empty() {
514 rejected_after_timestamps.sort();
515
516 let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
517 http_client.clone(),
518 &rejected_after_timestamps,
519 max_rows_per_timestamp,
520 background_executor.clone(),
521 )
522 .await?;
523 examples.append(&mut rejected_examples);
524 }
525
526 if !requested_after_timestamps.is_empty() {
527 requested_after_timestamps.sort();
528
529 let mut requested_examples = pull_examples::fetch_requested_examples_after(
530 http_client.clone(),
531 &requested_after_timestamps,
532 max_rows_per_timestamp,
533 background_executor.clone(),
534 )
535 .await?;
536 examples.append(&mut requested_examples);
537 }
538
539 if !rated_after_inputs.is_empty() {
540 rated_after_inputs.sort();
541
542 let mut rated_examples = pull_examples::fetch_rated_examples_after(
543 http_client,
544 &rated_after_inputs,
545 max_rows_per_timestamp,
546 background_executor,
547 )
548 .await?;
549 examples.append(&mut rated_examples);
550 }
551 }
552
553 crate::example::sort_examples_by_repo_and_rev(&mut examples);
554
555 if let Some(name_filter) = &args.name {
556 examples.retain(|example| example.spec.name.contains(name_filter));
557 }
558 if let Some(repo_filter) = &args.repo {
559 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
560 }
561
562 // Skip resume logic for --in-place since input and output are the same file,
563 // which would incorrectly treat all input examples as already processed.
564 if !args.in_place {
565 if let Some(path) = output_path {
566 resume_from_output(path, &mut examples);
567 }
568 }
569
570 if let Some(offset) = args.offset {
571 examples.splice(0..offset, []);
572 }
573
574 if let Some(limit) = args.limit {
575 examples.truncate(limit);
576 }
577
578 let progress = Progress::global();
579 progress.set_total_examples(examples.len());
580 progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
581
582 Ok(examples)
583}
584
585fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
586 let mut hasher = collections::FxHasher::default();
587 spec.hash(&mut hasher);
588 hasher.finish()
589}
590
591fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
592 let file = match File::open(path) {
593 Ok(f) => f,
594 Err(_) => return,
595 };
596
597 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
598
599 let reader = BufReader::new(file);
600 let mut kept_lines = Vec::new();
601 let mut kept_hashes = HashSet::default();
602
603 for line in reader.lines() {
604 let line = match line {
605 Ok(l) => l,
606 Err(_) => continue,
607 };
608
609 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
610 let hash = spec_hash(&output_example.spec);
611 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
612 kept_hashes.insert(hash);
613 kept_lines.push(line);
614 }
615 }
616 }
617
618 let total = examples.len();
619 let already_processed = kept_hashes.len();
620
621 eprintln!(
622 "Resuming: {}/{} examples already processed",
623 already_processed, total
624 );
625
626 let file = OpenOptions::new()
627 .write(true)
628 .truncate(true)
629 .open(path)
630 .expect("Failed to open output file for rewriting");
631 let mut writer = BufWriter::new(file);
632 for line in &kept_lines {
633 writeln!(writer, "{}", line).expect("Failed to write to output file");
634 }
635 writer.flush().expect("Failed to flush output file");
636
637 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
638}
639
fn main() {
    let args = EpArgs::parse();

    // --printenv short-circuits everything: dump the shell env and exit.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // No subcommand: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that don't need the headless GPUI application are handled
    // synchronously here and return early; everything else falls through to
    // the example-processing pipeline in `app.run` below.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Removes all repos/worktrees under the data dir.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Qa(qa_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                // NOTE(review): `splice` panics if offset > examples.len().
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::Repair(repair_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                // NOTE(review): `splice` panics if offset > examples.len().
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) =
                    repair::run_repair(&mut examples, repair_args, output.as_ref()).await
                {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

    // Remaining commands run inside the headless GPUI app so they can use
    // projects, language servers, and the edit-prediction store.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Pull down any finished provider batches before processing.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // With a single example, always fail fast so errors surface.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // --in-place writes to a temp file which is renamed over
                    // the original after all tasks finish (see bottom).
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // Single background writer task; workers send finished
                    // JSONL lines over the channel so writes are serialized.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Work queue: examples grouped by repository so each worker
                // owns one repo (and its project state) at a time.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch to the per-command pipeline stage.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::Qa(_)
                                        | Command::Repair(_) => {
                                            // Handled synchronously before app.run.
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are still written unless
                                // --failed=skip/skip-no-files was given.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o: emit JSONL to stdout (except for
                                        // eval, which prints a report at the end).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Tear down language servers and cached project state
                            // for this repo before moving to the next group.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            // Drop per-example project state before archiving.
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Flush provider batches accumulated during this run.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1093
1094async fn handle_error(
1095 error: anyhow::Error,
1096 args: &EpArgs,
1097 command: &Command,
1098 app_state: &Arc<headless::EpAppState>,
1099 failfast_on_single_example: bool,
1100 example: &Example,
1101) {
1102 Progress::global().increment_failed();
1103
1104 let msg;
1105 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1106 let example_name = example.spec.filename();
1107
1108 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1109 app_state
1110 .fs
1111 .write(
1112 &failed_example_path,
1113 &serde_json::to_vec_pretty(&example).unwrap(),
1114 )
1115 .await
1116 .unwrap();
1117 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1118 app_state
1119 .fs
1120 .write(&err_path, format!("{error:?}").as_bytes())
1121 .await
1122 .unwrap();
1123
1124 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1125 let mut file = OpenOptions::new()
1126 .create(true)
1127 .append(true)
1128 .open(&failed_jsonl_path)
1129 .expect("Failed to open failed.jsonl");
1130 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1131 .expect("Failed to write to failed.jsonl");
1132
1133 let cursor_path = example
1134 .repo_name()
1135 .unwrap()
1136 .worktree_path()
1137 .join(&example.spec.cursor_path);
1138 msg = format!(
1139 indoc::indoc! {"
1140 While processing \"{}\":
1141
1142 \x1b[31m{:?}\x1b[0m
1143
1144 Example: \x1b[36m{}\x1b[0m
1145 Error file: \x1b[36m{}\x1b[0m
1146 Cursor file: \x1b[36m{}\x1b[0m
1147 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1148 "},
1149 example.spec.name,
1150 error,
1151 failed_example_path.display(),
1152 err_path.display(),
1153 cursor_path.display(),
1154 command,
1155 failed_example_path.display(),
1156 );
1157 } else {
1158 msg = format!(
1159 indoc::indoc! {"
1160 While processing \"{}\":
1161
1162 \x1b[31m{:?}\x1b[0m
1163 "},
1164 example.spec.name, error
1165 );
1166 }
1167
1168 if args.failfast || failfast_on_single_example {
1169 Progress::global().finalize();
1170 panic!("{}", msg);
1171 } else {
1172 log::error!("{}", msg);
1173 }
1174}