1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25mod synthesize;
26mod word_diff;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
use zeta_prompt::ZetaVersion;

use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
43
44use crate::distill::run_distill;
45use crate::example::{Example, group_examples_by_repo, read_example_files};
46use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
47use crate::format_prompt::run_format_prompt;
48use crate::load_project::run_load_project;
49use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
50use crate::predict::run_prediction;
51use crate::progress::Progress;
52use crate::retrieve_context::run_context_retrieval;
53use crate::score::run_scoring;
54use crate::split_commit::SplitCommitArgs;
55use crate::split_dataset::SplitArgs;
56use crate::synthesize::{SynthesizeConfig, run_synthesize};
57
// Command-line arguments for the `ep` edit-prediction tool.
//
// NOTE: plain `//` comments are used for added documentation (not `///`)
// because clap's derive turns doc comments into `--help` text and this is a
// documentation-only change.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the inherited shell environment and exit immediately.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Upper bound on repository groups processed concurrently (see `main`).
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of examples to skip from the front of the sorted, filtered list.
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    // Subcommand to run; when absent, `main` prints help and exits.
    #[command(subcommand)]
    command: Option<Command>,
    // Input files or Snowflake specifiers; `INPUTS_HELP` documents the syntax.
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file (or directory when `--markdown` is set).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (see `EpArgs::output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Panic on the first failed example instead of continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
96
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // Unlike the type-level doc above: this variant also suppresses the
    // per-example files under failed/ and failed.jsonl (see `handle_error`),
    // not just the main output.
    SkipNoFiles,
}
109
// Long-form help text attached to the `inputs` positional argument of
// `EpArgs`; rendered verbatim by clap.
//
// NOTE(review): `load_examples` also accepts a `requested-after:{timestamp}`
// specifier (via `pull_examples::parse_requested_after_input`) that is not
// documented in this text — consider adding it.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
167
// All `ep` subcommands. Variant doc comments double as clap's per-subcommand
// help text, so added notes below use plain `//` comments.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    // Reuses PredictArgs since scoring shares the provider/repetition options.
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    // Runs scoring per example (see `main`) and then prints a report / writes
    // an optional summary JSON.
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
}
207
208impl Display for Command {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 match self {
211 Command::Read => write!(f, "read"),
212 Command::LoadProject => write!(f, "load-project"),
213 Command::Context => write!(f, "context"),
214 Command::FormatPrompt(args) => {
215 write!(f, "format-prompt --provider={}", args.provider)
216 }
217 Command::Predict(args) => match &args.provider {
218 Some(provider) => write!(f, "predict --provider={}", provider),
219 None => write!(f, "predict"),
220 },
221 Command::ParseOutput => write!(f, "parse-output"),
222 Command::Score(args) => match &args.provider {
223 Some(provider) => write!(f, "score --provider={}", provider),
224 None => write!(f, "score"),
225 },
226 Command::Distill => write!(f, "distill"),
227 Command::Eval(args) => match &args.predict.provider {
228 Some(provider) => write!(f, "eval --provider={}", provider),
229 None => write!(f, "eval"),
230 },
231 Command::Synthesize(args) => {
232 write!(f, "synthesize --repos {}", args.repos.join(" "))
233 }
234 Command::Clean => write!(f, "clean"),
235 Command::SplitCommit(_) => write!(f, "split-commit"),
236 Command::Split(_) => write!(f, "split"),
237 Command::FilterLanguages(_) => write!(f, "filter-languages"),
238 Command::ImportBatch(args) => {
239 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
240 }
241 Command::Qa(_) => {
242 write!(f, "qa")
243 }
244 Command::Repair(_) => {
245 write!(f, "repair")
246 }
247 }
248 }
249}
250
// Arguments for `format-prompt`. Unlike `PredictArgs`, the provider here is
// always present, falling back to `PredictionProvider::default()` (zeta2).
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
256
// Shared arguments for `predict`, `score`, and (via `EvalArgs`) `eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider; when absent, `Display` prints the bare command name
    // and `predict::sync_batches` receives `None`.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Number of repetitions per example — presumably repeated prediction
    // runs; confirm against `run_prediction`.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
267
// Arguments for `eval`: the shared prediction/scoring options plus an
// optional path for a machine-readable summary.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
276
/// Which LLM acts as the "teacher" model for the teacher prediction
/// providers. See [`TeacherBackend::model_name`] for the concrete model
/// identifiers sent to the APIs.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Anthropic Claude Sonnet 4.5 (`claude-sonnet-4-5`).
    Sonnet45,
    /// OpenAI GPT 5.2 (`gpt-5.2`).
    Gpt52,
}
282
283impl std::fmt::Display for TeacherBackend {
284 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285 match self {
286 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
287 TeacherBackend::Gpt52 => write!(f, "gpt52"),
288 }
289 }
290}
291
impl std::str::FromStr for TeacherBackend {
    type Err = anyhow::Error;

    /// Parses a backend tag case-insensitively, accepting several aliases
    /// per backend.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
            // Extra alias mapped to Sonnet45 — presumably a legacy tag kept
            // for compatibility with existing data; deliberately not listed
            // in the error message below. TODO confirm it is still needed.
            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
            _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
        }
    }
}
304
305impl TeacherBackend {
306 pub fn model_name(&self) -> &'static str {
307 match self {
308 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
309 TeacherBackend::Gpt52 => "gpt-5.2",
310 }
311 }
312}
313
/// The backend used to produce (or repair) edit predictions.
///
/// Parsed from strings like `zeta2:<version>` or `teacher:<backend>` — see
/// the `FromStr` impl for the accepted syntax; serialized through `Display`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    /// Sweep model.
    Sweep,
    /// Mercury model.
    Mercury,
    /// First-generation zeta model.
    Zeta1,
    /// Zeta 2 at a specific [`ZetaVersion`].
    Zeta2(ZetaVersion),
    /// Teacher LLM backend (default request path).
    Teacher(TeacherBackend),
    /// Teacher LLM backend; the name suggests requests bypass the batching
    /// client — confirm in the predict module.
    TeacherNonBatching(TeacherBackend),
    /// Repair provider (see the `repair` module).
    Repair,
}
324
325impl Default for PredictionProvider {
326 fn default() -> Self {
327 PredictionProvider::Zeta2(ZetaVersion::default())
328 }
329}
330
331impl std::fmt::Display for PredictionProvider {
332 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333 match self {
334 PredictionProvider::Sweep => write!(f, "sweep"),
335 PredictionProvider::Mercury => write!(f, "mercury"),
336 PredictionProvider::Zeta1 => write!(f, "zeta1"),
337 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
338 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
339 PredictionProvider::TeacherNonBatching(backend) => {
340 write!(f, "teacher-non-batching:{backend}")
341 }
342 PredictionProvider::Repair => write!(f, "repair"),
343 }
344 }
345}
346
impl std::str::FromStr for PredictionProvider {
    type Err = anyhow::Error;

    /// Parses a provider spec of the form `name` or `name:arg`, where `arg`
    /// is a `ZetaVersion` for `zeta2` and a `TeacherBackend` for the teacher
    /// variants. The name is matched case-insensitively; the arg keeps its
    /// original casing and is parsed by the respective type.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Split on the first ':' into (provider, optional argument).
        let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));

        let provider_lower = provider.to_lowercase();
        match provider_lower.as_str() {
            "sweep" => Ok(PredictionProvider::Sweep),
            "mercury" => Ok(PredictionProvider::Mercury),
            "zeta1" => Ok(PredictionProvider::Zeta1),
            "zeta2" => {
                // A missing version falls back to ZetaVersion::default().
                let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
                Ok(PredictionProvider::Zeta2(version))
            }
            "teacher" => {
                // A missing backend defaults to Sonnet 4.5.
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::Sonnet45);
                Ok(PredictionProvider::Teacher(backend))
            }
            "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::Sonnet45);
                Ok(PredictionProvider::TeacherNonBatching(backend))
            }
            "repair" => Ok(PredictionProvider::Repair),
            _ => {
                anyhow::bail!(
                    "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
                    For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
                    For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
                    Available zeta versions:\n{}",
                    ZetaVersion::options_as_string()
                )
            }
        }
    }
}
389
390impl Serialize for PredictionProvider {
391 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
392 where
393 S: Serializer,
394 {
395 serializer.serialize_str(&self.to_string())
396 }
397}
398
399impl<'de> Deserialize<'de> for PredictionProvider {
400 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
401 where
402 D: Deserializer<'de>,
403 {
404 let s = String::deserialize(deserializer)?;
405 s.parse().map_err(serde::de::Error::custom)
406 }
407}
408
// Arguments for `synthesize`; converted into a `SynthesizeConfig` in `main`
// (the output directory comes from the global `-o` flag).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
427
// Arguments for `import-batch`; handled entirely in `main`'s `ImportBatch`
// arm, before the headless app is started.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
437
// Which LLM batch API `import-batch` talks to; clap parses the kebab-case
// variant names on the command line.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
443
444impl EpArgs {
445 fn output_path(&self) -> Option<PathBuf> {
446 if self.in_place {
447 if self.inputs.len() == 1 {
448 self.inputs.first().cloned()
449 } else {
450 panic!("--in-place requires exactly one input file")
451 }
452 } else {
453 self.output.clone()
454 }
455 }
456}
457
458async fn load_examples(
459 http_client: Arc<dyn http_client::HttpClient>,
460 args: &EpArgs,
461 output_path: Option<&PathBuf>,
462 background_executor: BackgroundExecutor,
463) -> anyhow::Result<Vec<Example>> {
464 let mut captured_after_timestamps = Vec::new();
465 let mut rejected_after_timestamps = Vec::new();
466 let mut requested_after_timestamps = Vec::new();
467 let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
468 Vec::new();
469 let mut file_inputs = Vec::new();
470
471 for input in &args.inputs {
472 let input_string = input.to_string_lossy();
473 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
474 captured_after_timestamps.push(timestamp.to_string());
475 } else if let Some(timestamp) =
476 pull_examples::parse_rejected_after_input(input_string.as_ref())
477 {
478 rejected_after_timestamps.push(timestamp.to_string());
479 } else if let Some(timestamp) =
480 pull_examples::parse_requested_after_input(input_string.as_ref())
481 {
482 requested_after_timestamps.push(timestamp.to_string());
483 } else if let Some((timestamp, rating_filter)) =
484 pull_examples::parse_rated_after_input(input_string.as_ref())
485 {
486 rated_after_inputs.push((timestamp.to_string(), rating_filter));
487 } else {
488 file_inputs.push(input.clone());
489 }
490 }
491
492 let mut examples = read_example_files(&file_inputs);
493
494 Progress::global().set_total_examples(examples.len());
495
496 let remaining_limit_for_snowflake =
497 args.limit.map(|limit| limit.saturating_sub(examples.len()));
498
499 if let Some(0) = remaining_limit_for_snowflake {
500 log::info!(
501 "skipping Snowflake inputs because --limit is already satisfied by example files"
502 );
503 } else {
504 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
505
506 if !captured_after_timestamps.is_empty() {
507 captured_after_timestamps.sort();
508
509 let mut captured_examples = pull_examples::fetch_captured_examples_after(
510 http_client.clone(),
511 &captured_after_timestamps,
512 max_rows_per_timestamp,
513 background_executor.clone(),
514 )
515 .await?;
516 examples.append(&mut captured_examples);
517 }
518
519 if !rejected_after_timestamps.is_empty() {
520 rejected_after_timestamps.sort();
521
522 let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
523 http_client.clone(),
524 &rejected_after_timestamps,
525 max_rows_per_timestamp,
526 background_executor.clone(),
527 )
528 .await?;
529 examples.append(&mut rejected_examples);
530 }
531
532 if !requested_after_timestamps.is_empty() {
533 requested_after_timestamps.sort();
534
535 let mut requested_examples = pull_examples::fetch_requested_examples_after(
536 http_client.clone(),
537 &requested_after_timestamps,
538 max_rows_per_timestamp,
539 background_executor.clone(),
540 )
541 .await?;
542 examples.append(&mut requested_examples);
543 }
544
545 if !rated_after_inputs.is_empty() {
546 rated_after_inputs.sort();
547
548 let mut rated_examples = pull_examples::fetch_rated_examples_after(
549 http_client,
550 &rated_after_inputs,
551 max_rows_per_timestamp,
552 background_executor,
553 )
554 .await?;
555 examples.append(&mut rated_examples);
556 }
557 }
558
559 crate::example::sort_examples_by_repo_and_rev(&mut examples);
560
561 if let Some(name_filter) = &args.name {
562 examples.retain(|example| example.spec.name.contains(name_filter));
563 }
564 if let Some(repo_filter) = &args.repo {
565 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
566 }
567
568 // Skip resume logic for --in-place since input and output are the same file,
569 // which would incorrectly treat all input examples as already processed.
570 if !args.in_place {
571 if let Some(path) = output_path {
572 resume_from_output(path, &mut examples);
573 }
574 }
575
576 if let Some(offset) = args.offset {
577 examples.splice(0..offset, []);
578 }
579
580 if let Some(limit) = args.limit {
581 examples.truncate(limit);
582 }
583
584 let progress = Progress::global();
585 progress.set_total_examples(examples.len());
586 progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
587
588 Ok(examples)
589}
590
591fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
592 let mut hasher = collections::FxHasher::default();
593 spec.hash(&mut hasher);
594 hasher.finish()
595}
596
597fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
598 let file = match File::open(path) {
599 Ok(f) => f,
600 Err(_) => return,
601 };
602
603 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
604
605 let reader = BufReader::new(file);
606 let mut kept_lines = Vec::new();
607 let mut kept_hashes = HashSet::default();
608
609 for line in reader.lines() {
610 let line = match line {
611 Ok(l) => l,
612 Err(_) => continue,
613 };
614
615 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
616 let hash = spec_hash(&output_example.spec);
617 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
618 kept_hashes.insert(hash);
619 kept_lines.push(line);
620 }
621 }
622 }
623
624 let total = examples.len();
625 let already_processed = kept_hashes.len();
626
627 eprintln!(
628 "Resuming: {}/{} examples already processed",
629 already_processed, total
630 );
631
632 let file = OpenOptions::new()
633 .write(true)
634 .truncate(true)
635 .open(path)
636 .expect("Failed to open output file for rewriting");
637 let mut writer = BufWriter::new(file);
638 for line in &kept_lines {
639 writeln!(writer, "{}", line).expect("Failed to write to output file");
640 }
641 writer.flush().expect("Failed to flush output file");
642
643 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
644}
645
/// Entry point.
///
/// Parses CLI args, short-circuits self-contained subcommands (import-batch,
/// clean, synthesize, split-commit, split, filter-languages, qa, repair) on a
/// `smol` executor, and runs everything else inside a headless GPUI
/// application: examples are loaded, grouped by repository, and processed by
/// up to `--max-parallelism` concurrent workers.
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // No subcommand: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that manage their own I/O and don't need the example
    // pipeline below are handled here and return early.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Removes all cloned repositories/worktrees under the data dir.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Qa(qa_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::Repair(repair_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) =
                    repair::run_repair(&mut examples, repair_args, output.as_ref()).await
                {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

    // Remaining subcommands run inside a headless GPUI application so they
    // can use projects, language servers, and background executors.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any pending batch state before processing begins.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // A single-example run fails fast so its error surfaces
                // immediately (see `handle_error`).
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // --in-place writes to a temp file that is renamed over
                    // the original at the end of the run.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // Dedicated writer task: workers send serialized lines
                    // over a channel, serializing access to the file.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Work queue: each entry is all examples for one repository.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                // Spawn up to --max-parallelism workers; each repeatedly pops
                // a repository group and processes its examples sequentially.
                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Run the per-example stage for the subcommand.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::Qa(_)
                                        | Command::Repair(_) => {
                                            // These were dispatched before the
                                            // headless app was started.
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        // One .md file per example, named
                                        // after the example.
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o given: stream JSONL to stdout
                                        // (except eval, which prints a report).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Tear down the repository's project state before
                            // moving on: stop language servers, detach the
                            // project from the prediction store, drop caches.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Sync batch state again after all workers finish.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // Eval prints its aggregated report at the end of the run.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1099
/// Records a failed example and either panics (fail-fast) or logs and
/// continues.
///
/// Unless `--failed=skip-no-files`, the example and its error are written to
/// the run's failed/ directory (`<name>.json` + `<name>_err.txt`), appended
/// to the run-wide `failed.jsonl`, and the message includes a ready-to-paste
/// re-run command. Fail-fast triggers when `--failfast` was passed or when
/// the run had exactly one example.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();

    let msg;
    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
        let example_name = example.spec.filename();

        // Persist the failing example itself so it can be re-run directly.
        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
        app_state
            .fs
            .write(
                &failed_example_path,
                &serde_json::to_vec_pretty(&example).unwrap(),
            )
            .await
            .unwrap();
        // And the error's debug representation alongside it.
        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
        app_state
            .fs
            .write(&err_path, format!("{error:?}").as_bytes())
            .await
            .unwrap();

        // Append to the run-wide failed.jsonl for batch inspection.
        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
        let mut file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&failed_jsonl_path)
            .expect("Failed to open failed.jsonl");
        writeln!(file, "{}", serde_json::to_string(example).unwrap())
            .expect("Failed to write to failed.jsonl");

        // Best-effort path to the cursor file inside the example's worktree,
        // falling back to the raw spec path when the repo name is unavailable.
        let cursor_path = match example.repo_name() {
            Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
            Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
        };
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m

                Example: \x1b[36m{}\x1b[0m
                Error file: \x1b[36m{}\x1b[0m
                Cursor file: \x1b[36m{}\x1b[0m
                Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
            "},
            example.spec.name,
            error,
            failed_example_path.display(),
            err_path.display(),
            cursor_path.display(),
            command,
            failed_example_path.display(),
        );
    } else {
        // skip-no-files: no artifacts on disk, just the message.
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m
            "},
            example.spec.name, error
        );
    }

    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}