1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25mod synthesize;
26mod word_diff;
27use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
28use collections::HashSet;
29use edit_prediction::EditPredictionStore;
30use futures::channel::mpsc;
31use futures::{SinkExt as _, StreamExt as _};
32use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
33use zeta_prompt::ZetaVersion;
34
35use reqwest_client::ReqwestClient;
36use serde::{Deserialize, Deserializer, Serialize, Serializer};
37use std::fmt::Display;
38use std::fs::{File, OpenOptions};
39use std::hash::{Hash, Hasher};
40use std::io::{BufRead, BufReader, BufWriter, Write};
41use std::sync::Mutex;
42use std::{path::PathBuf, sync::Arc};
43
44use crate::distill::run_distill;
45use crate::example::{Example, group_examples_by_repo, read_example_files};
46use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
47use crate::format_prompt::run_format_prompt;
48use crate::load_project::run_load_project;
49use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
50use crate::predict::run_prediction;
51use crate::progress::Progress;
52use crate::retrieve_context::run_context_retrieval;
53use crate::score::run_scoring;
54use crate::split_commit::SplitCommitArgs;
55use crate::split_dataset::SplitArgs;
56use crate::synthesize::{SynthesizeConfig, run_synthesize};
57
// Top-level CLI arguments for `ep`. Most flags are `global = true` so they can
// appear before or after the subcommand. NOTE: clap turns `///` doc comments on
// fields into --help text, so review notes here use `//` to avoid changing it.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment as `ep` sees it, then exit (see main()).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of concurrent worker tasks; each claims one repo's examples at a time.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to drop (applied after filtering/resume).
    // NOTE(review): load_examples uses `splice(0..offset, [])`, which panics if
    // offset exceeds the number of examples — confirm whether clamping was intended.
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    // Optional subcommand; when absent, main() prints help and exits.
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or Snowflake specifiers (see INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output path; a JSONL file normally, a directory with --markdown.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (via a temp file renamed at the end).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Panic on the first failed example instead of logging and continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
96
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// NOTE(review): the doc line above overstates — with `skip-no-files`,
// handle_error() skips the failed/ directory snapshot and the failed.jsonl
// append entirely; only the error message is logged. Confirm intended wording.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
109
/// Long help text attached to the positional `inputs` argument of [`EpArgs`].
/// Documents both plain file-path inputs and the Snowflake `*-after:{timestamp}`
/// specifiers parsed in `load_examples` via the `pull_examples` module.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.
      These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
      Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
      Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that users explicitly rated as positive or negative via the
      rate completions modal. Only zeta2 predictions are included.
      - Positive ratings: output becomes expected_patches
      - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
      Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
      Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
      EP_SNOWFLAKE_API_KEY
      EP_SNOWFLAKE_BASE_URL

  Optional:
      EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
167
// All `ep` subcommands. Clean, Synthesize, SplitCommit, Split, FilterLanguages,
// and ImportBatch run standalone in main() before the headless app starts; the
// rest are dispatched per-example inside the worker loop. The Display impl
// below renders each variant as the CLI string used in re-run hints.
// (Variant `///` docs become clap help text — left untouched.)
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
}
207
208impl Display for Command {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 match self {
211 Command::Read => write!(f, "read"),
212 Command::LoadProject => write!(f, "load-project"),
213 Command::Context => write!(f, "context"),
214 Command::FormatPrompt(args) => {
215 write!(f, "format-prompt --provider={}", args.provider)
216 }
217 Command::Predict(args) => match &args.provider {
218 Some(provider) => write!(f, "predict --provider={}", provider),
219 None => write!(f, "predict"),
220 },
221 Command::ParseOutput => write!(f, "parse-output"),
222 Command::Score(args) => match &args.provider {
223 Some(provider) => write!(f, "score --provider={}", provider),
224 None => write!(f, "score"),
225 },
226 Command::Distill => write!(f, "distill"),
227 Command::Eval(args) => match &args.predict.provider {
228 Some(provider) => write!(f, "eval --provider={}", provider),
229 None => write!(f, "eval"),
230 },
231 Command::Synthesize(args) => {
232 write!(f, "synthesize --repos {}", args.repos.join(" "))
233 }
234 Command::Clean => write!(f, "clean"),
235 Command::SplitCommit(_) => write!(f, "split-commit"),
236 Command::Split(_) => write!(f, "split"),
237 Command::FilterLanguages(_) => write!(f, "filter-languages"),
238 Command::ImportBatch(args) => {
239 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
240 }
241 Command::Qa(_) => {
242 write!(f, "qa")
243 }
244 Command::Repair(_) => {
245 write!(f, "repair")
246 }
247 }
248 }
249}
250
// Arguments for the `format-prompt` subcommand.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to emit; defaults to the default
    // PredictionProvider (zeta2 at its default version).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
256
// Arguments shared by `predict` and `score` (and, flattened, by `eval`).
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Provider to use; optional — presumably a default is chosen downstream
    // in predict::run_prediction when omitted (confirm there).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // How many predictions to produce per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
267
// Arguments for `eval`: prediction/scoring options plus an optional path for
// the aggregate JSON summary written by score::write_summary_json.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
276
/// Which large "teacher" model backs teacher-mode predictions.
/// Parsed from aliases via `FromStr`; rendered as `sonnet45`/`gpt52` via `Display`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Anthropic Claude — maps to model name "claude-sonnet-4-5".
    Sonnet45,
    /// OpenAI GPT — maps to model name "gpt-5.2".
    Gpt52,
}
282
283impl std::fmt::Display for TeacherBackend {
284 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285 match self {
286 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
287 TeacherBackend::Gpt52 => write!(f, "gpt52"),
288 }
289 }
290}
291
292impl std::str::FromStr for TeacherBackend {
293 type Err = anyhow::Error;
294
295 fn from_str(s: &str) -> Result<Self, Self::Err> {
296 match s.to_lowercase().as_str() {
297 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
298 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
299 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
300 _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
301 }
302 }
303}
304
305impl TeacherBackend {
306 pub fn model_name(&self) -> &'static str {
307 match self {
308 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
309 TeacherBackend::Gpt52 => "gpt-5.2",
310 }
311 }
312}
313
/// Backend used to produce edit predictions. Parsed from strings like
/// `zeta2:<version>` or `teacher:<backend>` (see `FromStr` below) and
/// serialized back through `Display` for JSONL round-tripping.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    /// Zeta2 at a specific version (defaults to `ZetaVersion::default()`).
    Zeta2(ZetaVersion),
    /// Predictions from a large teacher model (see `TeacherBackend`).
    Teacher(TeacherBackend),
    /// Same as `Teacher` but — per the name — without request batching;
    /// confirm exact behavior in the predict module.
    TeacherNonBatching(TeacherBackend),
    /// Predictions produced by the `repair` command.
    Repair,
}
324
325impl Default for PredictionProvider {
326 fn default() -> Self {
327 PredictionProvider::Zeta2(ZetaVersion::default())
328 }
329}
330
331impl std::fmt::Display for PredictionProvider {
332 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333 match self {
334 PredictionProvider::Sweep => write!(f, "sweep"),
335 PredictionProvider::Mercury => write!(f, "mercury"),
336 PredictionProvider::Zeta1 => write!(f, "zeta1"),
337 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
338 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
339 PredictionProvider::TeacherNonBatching(backend) => {
340 write!(f, "teacher-non-batching:{backend}")
341 }
342 PredictionProvider::Repair => write!(f, "repair"),
343 }
344 }
345}
346
347impl std::str::FromStr for PredictionProvider {
348 type Err = anyhow::Error;
349
350 fn from_str(s: &str) -> Result<Self, Self::Err> {
351 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
352
353 let provider_lower = provider.to_lowercase();
354 match provider_lower.as_str() {
355 "sweep" => Ok(PredictionProvider::Sweep),
356 "mercury" => Ok(PredictionProvider::Mercury),
357 "zeta1" => Ok(PredictionProvider::Zeta1),
358 "zeta2" => {
359 let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
360 Ok(PredictionProvider::Zeta2(version))
361 }
362 "teacher" => {
363 let backend = arg
364 .map(|a| a.parse())
365 .transpose()?
366 .unwrap_or(TeacherBackend::Sonnet45);
367 Ok(PredictionProvider::Teacher(backend))
368 }
369 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
370 let backend = arg
371 .map(|a| a.parse())
372 .transpose()?
373 .unwrap_or(TeacherBackend::Sonnet45);
374 Ok(PredictionProvider::TeacherNonBatching(backend))
375 }
376 "repair" => Ok(PredictionProvider::Repair),
377 _ => {
378 anyhow::bail!(
379 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
380 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
381 For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
382 Available zeta versions:\n{}",
383 ZetaVersion::options_as_string()
384 )
385 }
386 }
387 }
388}
389
390impl Serialize for PredictionProvider {
391 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
392 where
393 S: Serializer,
394 {
395 serializer.serialize_str(&self.to_string())
396 }
397}
398
399impl<'de> Deserialize<'de> for PredictionProvider {
400 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
401 where
402 D: Deserializer<'de>,
403 {
404 let s = String::deserialize(deserializer)?;
405 s.parse().map_err(serde::de::Error::custom)
406 }
407}
408
// Arguments for `synthesize`; converted into a SynthesizeConfig in main()
// together with the required -o output directory.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
427
// Arguments for `import-batch`; handled standalone in main() before the
// headless app starts.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
437
// Batch-API vendor selector for `import-batch`. clap::ValueEnum, so the CLI
// values are `anthropic` / `openai`; variant doc comments would become value
// help text, hence `//` comments here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
443
444impl EpArgs {
445 fn output_path(&self) -> Option<PathBuf> {
446 if self.in_place {
447 if self.inputs.len() == 1 {
448 self.inputs.first().cloned()
449 } else {
450 panic!("--in-place requires exactly one input file")
451 }
452 } else {
453 self.output.clone()
454 }
455 }
456}
457
/// Loads the run's examples from every CLI input — local files plus Snowflake
/// `*-after:{timestamp}` specifiers — then applies name/repo filters, resume
/// logic against an existing output file, and --offset/--limit.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Bucket each input by kind: the four Snowflake specifier shapes vs. files.
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            // Anything that isn't a recognized specifier is treated as a file path.
            file_inputs.push(input.clone());
        }
    }

    // Local files load first; Snowflake fetches below only fill the remainder
    // of --limit (if one was given).
    let mut examples = read_example_files(&file_inputs);

    Progress::global().set_total_examples(examples.len());

    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // 5000 matches the documented default on --limit for Snowflake pulls.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                background_executor.clone(),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                background_executor.clone(),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                background_executor.clone(),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                background_executor,
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    // Group examples by repo/rev so the worker loop can reuse project state.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(offset) = args.offset {
        // NOTE(review): splice panics if offset > examples.len() — confirm
        // whether out-of-range offsets should be clamped instead.
        examples.splice(0..offset, []);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Re-publish totals now that filtering/resume/offset/limit have been applied.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
592
593fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
594 let mut hasher = collections::FxHasher::default();
595 spec.hash(&mut hasher);
596 hasher.finish()
597}
598
/// Resumes an interrupted run: reads the existing output file at `path`, keeps
/// only lines whose spec matches a current input example AND whose result is
/// "complete" for `command`, rewrites the file with just those lines, and
/// removes the corresponding examples from `examples` so they are not
/// processed again (the worker loop then appends new results after them).
fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
    // No existing output file (or unreadable) means nothing to resume from.
    let file = match File::open(path) {
        Ok(f) => f,
        Err(_) => return,
    };

    // Identity hashes of the current input set.
    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();

    let reader = BufReader::new(file);
    let mut kept_lines = Vec::new();
    let mut kept_hashes = HashSet::default();

    for line in reader.lines() {
        // Unreadable or unparsable lines are silently dropped from the rewrite.
        let line = match line {
            Ok(l) => l,
            Err(_) => continue,
        };

        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
            let hash = spec_hash(&output_example.spec);
            // Keep at most one line per spec, and only for specs still in the input set.
            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
                // "Complete" is command-specific: qa needs a confidence value,
                // repair needs a repair prediction with an actual patch; every
                // other command counts any prior line as complete.
                let is_complete = match command {
                    Command::Qa(_) => output_example
                        .qa
                        .first()
                        .and_then(|q| q.as_ref())
                        .and_then(|q| q.confidence)
                        .is_some(),
                    Command::Repair(_) => output_example.predictions.iter().any(|p| {
                        p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
                    }),
                    _ => true,
                };
                if is_complete {
                    kept_hashes.insert(hash);
                    kept_lines.push(line);
                }
            }
        }
    }

    let total = examples.len();
    let already_processed = kept_hashes.len();

    eprintln!(
        "Resuming: {}/{} examples already processed",
        already_processed, total
    );

    // Rewrite the output file in place with only the kept lines; main() opens
    // it in append mode afterwards for newly processed examples.
    let file = OpenOptions::new()
        .write(true)
        .truncate(true)
        .open(path)
        .expect("Failed to open output file for rewriting");
    let mut writer = BufWriter::new(file);
    for line in &kept_lines {
        writeln!(writer, "{}", line).expect("Failed to write to output file");
    }
    writer.flush().expect("Failed to flush output file");

    // Drop already-processed examples from the in-memory input set.
    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
}
661
/// Entry point. Parses CLI arguments, runs standalone subcommands directly,
/// and for example-processing subcommands spins up a headless GPUI app with a
/// pool of worker tasks that process examples grouped by repository.
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // With no subcommand, print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that don't process examples run here and return early.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // NOTE(review): unwrap panics if DATA_DIR doesn't exist — confirm
            // whether `clean` on a fresh checkout should be a no-op instead.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesize writes generated examples into a directory rather
            // than a single JSONL file, so -o is mandatory here.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }

        _ => {}
    }

    // Everything else needs the headless GPUI app and project infrastructure.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync provider batch state before processing (see the
                // per-module sync_batches for what this reconciles).
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // A single-example run always fails fast so errors surface immediately.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // --in-place writes to a fresh temp file (renamed over the
                    // input at the end); otherwise append so resumed runs keep
                    // the lines resume_from_output preserved.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // A single background task owns the writer; worker tasks
                    // stream finished example lines through this channel.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Work queue of per-repo example groups; each worker claims a
                // whole repo at a time so project state is reused within it.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the per-example stage for the active command.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        // These were all handled (with an early return)
                                        // before the app started.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Emit the example unless it failed and --failed=skip*.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o: stream JSONL to stdout (except eval,
                                        // which prints its report at the end).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Done with this repo: shut down its language servers and
                            // drop cached project state before moving to the next one.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Sync provider batches again after processing (mirrors the pre-run sync).
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // Eval prints an aggregate report (and optional JSON summary) at the end.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app once the run completes.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1079
1080async fn handle_error(
1081 error: anyhow::Error,
1082 args: &EpArgs,
1083 command: &Command,
1084 app_state: &Arc<headless::EpAppState>,
1085 failfast_on_single_example: bool,
1086 example: &Example,
1087) {
1088 Progress::global().increment_failed();
1089
1090 let msg;
1091 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1092 let example_name = example.spec.filename();
1093
1094 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1095 app_state
1096 .fs
1097 .write(
1098 &failed_example_path,
1099 &serde_json::to_vec_pretty(&example).unwrap(),
1100 )
1101 .await
1102 .unwrap();
1103 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1104 app_state
1105 .fs
1106 .write(&err_path, format!("{error:?}").as_bytes())
1107 .await
1108 .unwrap();
1109
1110 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1111 let mut file = OpenOptions::new()
1112 .create(true)
1113 .append(true)
1114 .open(&failed_jsonl_path)
1115 .expect("Failed to open failed.jsonl");
1116 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1117 .expect("Failed to write to failed.jsonl");
1118
1119 let cursor_path = match example.repo_name() {
1120 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1121 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1122 };
1123 msg = format!(
1124 indoc::indoc! {"
1125 While processing \"{}\":
1126
1127 \x1b[31m{:?}\x1b[0m
1128
1129 Example: \x1b[36m{}\x1b[0m
1130 Error file: \x1b[36m{}\x1b[0m
1131 Cursor file: \x1b[36m{}\x1b[0m
1132 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1133 "},
1134 example.spec.name,
1135 error,
1136 failed_example_path.display(),
1137 err_path.display(),
1138 cursor_path.display(),
1139 command,
1140 failed_example_path.display(),
1141 );
1142 } else {
1143 msg = format!(
1144 indoc::indoc! {"
1145 While processing \"{}\":
1146
1147 \x1b[31m{:?}\x1b[0m
1148 "},
1149 example.spec.name, error
1150 );
1151 }
1152
1153 if args.failfast || failfast_on_single_example {
1154 Progress::global().finalize();
1155 panic!("{}", msg);
1156 } else {
1157 log::error!("{}", msg);
1158 }
1159}