mod anthropic_client;
mod distill;
mod example;
mod filter_languages;
mod format_prompt;
mod git;
mod headless;
mod llm_client;
mod load_project;
mod metrics;
mod openai_client;
mod parse_output;
mod paths;
mod predict;
mod progress;
mod pull_examples;
mod qa;
mod reorder_patch;
mod repair;
mod retrieve_context;
mod reversal_tracking;
mod score;
mod split_commit;
mod split_dataset;
mod synthesize;
mod word_diff;

use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
use zeta_prompt::ZetaVersion;

use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::sync::Mutex;
use std::{path::PathBuf, sync::Arc};

use crate::distill::run_distill;
use crate::example::{Example, group_examples_by_repo, read_example_files};
use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
use crate::split_dataset::SplitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};
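
/// `ep`: command-line harness for the edit prediction pipeline (reading
/// examples, loading projects, retrieving context, formatting prompts,
/// predicting, and scoring).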
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    /// Print the shell environment and exit.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    /// Maximum number of repositories to process in parallel.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// Process at most this many examples.
    #[clap(long, global = true)]
    limit: Option<usize>,
    /// Skip this many examples before processing.
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    /// Path to write output to (a directory when --markdown is set).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    /// Rewrite the single input file in place (via a temp file renamed at the end).
    #[arg(long, short, global = true)]
    in_place: bool,
    /// Abort on the first failed example.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in the output: keep them or skip them.
    /// Unless `skip-no-files` is used, failed examples are also logged to the
    /// run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}

/// Controls whether failed examples are included in the main output.
/// Except with `skip-no-files`, failed examples are also logged to the run's
/// failed/ directory.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Exclude failed examples from the main output and skip writing the
    /// failure artifacts (failed/ files and failed.jsonl)
    SkipNoFiles,
}

const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

    path
        Path to an example file (.md, .json, or .jsonl)

    captured-after:{timestamp}
        Fetch captured examples from Snowflake after the given RFC3339 timestamp.
        These are examples captured via the "Capture Edit Prediction Example" action.

    rejected-after:{timestamp}
        Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
        These are predictions that were shown to users but rejected (useful for DPO training).

    requested-after:{timestamp}
        Fetch edit prediction requests from Snowflake after the given RFC3339 timestamp.

    Required environment variables to connect to Snowflake:
        EP_SNOWFLAKE_API_KEY
        EP_SNOWFLAKE_BASE_URL

    Optional:
        EP_SNOWFLAKE_ROLE

Examples:

    # Read examples from a file
    ep read examples.jsonl -o output.jsonl

    # Read captured examples after a timestamp
    ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

    # Read rejected predictions for DPO training
    ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

    # Mix multiple input sources
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
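
/// Subcommands, listed roughly in pipeline order.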
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Run edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Compute a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepare a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
}

impl Display for Command {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Command::Read => write!(f, "read"),
            Command::LoadProject => write!(f, "load-project"),
            Command::Context => write!(f, "context"),
            Command::FormatPrompt(args) => {
                write!(f, "format-prompt --provider={}", args.provider)
            }
            Command::Predict(args) => match &args.provider {
                Some(provider) => write!(f, "predict --provider={}", provider),
                None => write!(f, "predict"),
            },
            Command::ParseOutput => write!(f, "parse-output"),
            Command::Score(args) => match &args.provider {
                Some(provider) => write!(f, "score --provider={}", provider),
                None => write!(f, "score"),
            },
            Command::Distill => write!(f, "distill"),
            Command::Eval(args) => match &args.predict.provider {
                Some(provider) => write!(f, "eval --provider={}", provider),
                None => write!(f, "eval"),
            },
            Command::Synthesize(args) => {
                write!(f, "synthesize --repos {}", args.repos.join(" "))
            }
            Command::Clean => write!(f, "clean"),
            Command::SplitCommit(_) => write!(f, "split-commit"),
            Command::Split(_) => write!(f, "split"),
            Command::FilterLanguages(_) => write!(f, "filter-languages"),
            Command::ImportBatch(args) => {
                write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
            }
            Command::Qa(_) => write!(f, "qa"),
            Command::Repair(_) => write!(f, "repair"),
        }
    }
}

#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}

#[derive(Debug, Args, Clone)]
struct PredictArgs {
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}

#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
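
/// Backend model used by the `teacher` family of prediction providers.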
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    Sonnet45,
    Gpt52,
}

impl std::fmt::Display for TeacherBackend {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
            TeacherBackend::Gpt52 => write!(f, "gpt52"),
        }
    }
}

impl std::str::FromStr for TeacherBackend {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
            // Zeta-version-style name accepted as an alias for the Sonnet teacher.
            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
            _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
        }
    }
}

impl TeacherBackend {
    pub fn model_name(&self) -> &'static str {
        match self {
            TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
            TeacherBackend::Gpt52 => "gpt-5.2",
        }
    }
}
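
/// Prediction provider selected with `--provider`/`-p`.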
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2(ZetaVersion),
    Teacher(TeacherBackend),
    TeacherNonBatching(TeacherBackend),
    RepairedTeacher(TeacherBackend),
    Repair,
}

impl Default for PredictionProvider {
    fn default() -> Self {
        PredictionProvider::Zeta2(ZetaVersion::default())
    }
}

impl std::fmt::Display for PredictionProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PredictionProvider::Sweep => write!(f, "sweep"),
            PredictionProvider::Mercury => write!(f, "mercury"),
            PredictionProvider::Zeta1 => write!(f, "zeta1"),
            PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
            PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
            PredictionProvider::TeacherNonBatching(backend) => {
                write!(f, "teacher-non-batching:{backend}")
            }
            PredictionProvider::RepairedTeacher(backend) => {
                write!(f, "repaired-teacher:{backend}")
            }
            PredictionProvider::Repair => write!(f, "repair"),
        }
    }
}

impl std::str::FromStr for PredictionProvider {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));

        let provider_lower = provider.to_lowercase();
        match provider_lower.as_str() {
            "sweep" => Ok(PredictionProvider::Sweep),
            "mercury" => Ok(PredictionProvider::Mercury),
            "zeta1" => Ok(PredictionProvider::Zeta1),
            "zeta2" => {
                let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
                Ok(PredictionProvider::Zeta2(version))
            }
            "teacher" => {
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::Sonnet45);
                Ok(PredictionProvider::Teacher(backend))
            }
            "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::Sonnet45);
                Ok(PredictionProvider::TeacherNonBatching(backend))
            }
            "repaired-teacher" | "repaired_teacher" | "repairedteacher" => {
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::Sonnet45);
                Ok(PredictionProvider::RepairedTeacher(backend))
            }
            "repair" => Ok(PredictionProvider::Repair),
            _ => {
                anyhow::bail!(
                    "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repaired-teacher, repair\n\
                    For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
                    For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
                    Available zeta versions:\n{}",
                    ZetaVersion::options_as_string()
                )
            }
        }
    }
}

impl Serialize for PredictionProvider {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(&self.to_string())
    }
}

impl<'de> Deserialize<'de> for PredictionProvider {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        s.parse().map_err(serde::de::Error::custom)
    }
}

#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}

#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}

impl EpArgs {
    fn output_path(&self) -> Option<PathBuf> {
        if self.in_place {
            if self.inputs.len() == 1 {
                self.inputs.first().cloned()
            } else {
                panic!("--in-place requires exactly one input file")
            }
        } else {
            self.output.clone()
        }
    }
}
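
/// Gather examples from file inputs and Snowflake specifiers, apply the global
/// name/repo/offset/limit filters, and drop examples already present in the
/// output file so interrupted runs can resume.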
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    Progress::global().set_total_examples(examples.len());

    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                background_executor.clone(),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                background_executor.clone(),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client,
                &requested_after_timestamps,
                max_rows_per_timestamp,
                background_executor,
            )
            .await?;
            examples.append(&mut requested_examples);
        }
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path {
            resume_from_output(path, &mut examples);
        }
    }

    if let Some(offset) = args.offset {
        examples.drain(..offset);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
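
/// Hash of an example's spec, used to match input examples against lines
/// already written to the output file.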
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
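
/// Rewrite the output file so it keeps only (deduplicated) lines whose spec
/// matches a pending input example, then remove those examples from the
/// pending set so they are not processed again.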
fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
    let file = match File::open(path) {
        Ok(f) => f,
        Err(_) => return,
    };

    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();

    let reader = BufReader::new(file);
    let mut kept_lines = Vec::new();
    let mut kept_hashes = HashSet::default();

    for line in reader.lines() {
        let line = match line {
            Ok(l) => l,
            Err(_) => continue,
        };

        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
            let hash = spec_hash(&output_example.spec);
            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
                kept_hashes.insert(hash);
                kept_lines.push(line);
            }
        }
    }

    let total = examples.len();
    let already_processed = kept_hashes.len();

    eprintln!(
        "Resuming: {}/{} examples already processed",
        already_processed, total
    );

    let file = OpenOptions::new()
        .write(true)
        .truncate(true)
        .open(path)
        .expect("Failed to open output file for rewriting");
    let mut writer = BufWriter::new(file);
    for line in &kept_lines {
        writeln!(writer, "{}", line).expect("Failed to write to output file");
    }
    writer.flush().expect("Failed to flush output file");

    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
}

fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Qa(qa_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.drain(..offset);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::Repair(repair_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.drain(..offset);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) =
                    repair::run_repair(&mut examples, repair_args, output.as_ref()).await
                {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                let markdown_output_dir = if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                    Some(dir.clone())
                } else {
                    None
                };

                // For --in-place, write to a temp file and rename it at the end
                // to avoid data loss on interruption
                let in_place_temp_path = if args.in_place {
                    output.as_ref().map(|path| {
                        let mut temp_path = path.clone();
                        temp_path.set_extension("jsonl.tmp");
                        temp_path
                    })
                } else {
                    None
                };

                let output_sender: Option<mpsc::UnboundedSender<String>> = if !args.markdown
                    && (args.output.is_some() || !matches!(command, Command::Eval(_)))
                {
                    let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                    write_path.map(|path| {
                        let file = if args.in_place {
                            // For --in-place, write to the temp file (truncating if it exists)
                            OpenOptions::new()
                                .create(true)
                                .write(true)
                                .truncate(true)
                                .open(path)
                                .expect("Failed to open temp output file")
                        } else {
                            // For regular output, append to support resuming
                            OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file")
                        };
                        let mut writer = BufWriter::new(file);
                        let (sender, mut receiver) = mpsc::unbounded::<String>();
                        cx.background_spawn(async move {
                            while let Some(line) = receiver.next().await {
                                writeln!(writer, "{}", line).expect("Failed to write example");
                                writer.flush().expect("Failed to flush output");
                            }
                        })
                        .detach();
                        sender
                    })
                } else {
                    None
                };

                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::Qa(_)
                                        | Command::Repair(_) => {
                                            // Handled earlier in main, before the app starts.
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref markdown_dir) = markdown_output_dir {
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }

                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                if let Command::Eval(args) = &command {
                    let examples = finished_examples.lock().unwrap();
                    score::print_report(&examples);
                    if let Some(summary_path) = &args.summary_json {
                        score::write_summary_json(&examples, summary_path)?;
                    }
                }

                // For --in-place, atomically rename the temp file over the original
                if let (Some(temp_path), Some(final_path)) = (&in_place_temp_path, &output) {
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
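
/// Record a failed example: unless `--failed skip-no-files` is set, write the
/// example and its error to the run's failed/ directory and append the example
/// to failed.jsonl; then either panic (failfast) or log the error.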
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();

    let msg = if !matches!(args.failed, FailedHandling::SkipNoFiles) {
        let example_name = example.spec.filename();

        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
        app_state
            .fs
            .write(
                &failed_example_path,
                &serde_json::to_vec_pretty(&example).unwrap(),
            )
            .await
            .unwrap();
        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
        app_state
            .fs
            .write(&err_path, format!("{error:?}").as_bytes())
            .await
            .unwrap();

        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
        let mut file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&failed_jsonl_path)
            .expect("Failed to open failed.jsonl");
        writeln!(file, "{}", serde_json::to_string(example).unwrap())
            .expect("Failed to write to failed.jsonl");

        let cursor_path = example
            .repo_name()
            .unwrap()
            .worktree_path()
            .join(&example.spec.cursor_path);
        format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m

                Example: \x1b[36m{}\x1b[0m
                Error file: \x1b[36m{}\x1b[0m
                Cursor file: \x1b[36m{}\x1b[0m
                Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
            "},
            example.spec.name,
            error,
            failed_example_path.display(),
            err_path.display(),
            cursor_path.display(),
            command,
            failed_example_path.display(),
        )
    } else {
        format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m
            "},
            example.spec.name, error
        )
    };

    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}
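
// A minimal test sketch for the parsers above, exercising only code defined in
// this file: `FromStr` aliases for `TeacherBackend`, and `Display`/`FromStr`
// round-trips for `PredictionProvider`. The `zeta2` variant is left out because
// its round-trip depends on the external `ZetaVersion` type. Assuming the
// package name from the re-run hint in `handle_error`, run these with
// `cargo test -p edit_prediction_cli`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn teacher_backend_aliases_parse() {
        assert_eq!(
            "sonnet".parse::<TeacherBackend>().unwrap(),
            TeacherBackend::Sonnet45
        );
        assert_eq!(
            "openai".parse::<TeacherBackend>().unwrap(),
            TeacherBackend::Gpt52
        );
        assert!("unknown-backend".parse::<TeacherBackend>().is_err());
    }

    #[test]
    fn prediction_provider_display_round_trips() {
        let providers = [
            PredictionProvider::Sweep,
            PredictionProvider::Teacher(TeacherBackend::Gpt52),
            PredictionProvider::TeacherNonBatching(TeacherBackend::Sonnet45),
            PredictionProvider::RepairedTeacher(TeacherBackend::Sonnet45),
            PredictionProvider::Repair,
        ];
        for provider in providers {
            let round_tripped: PredictionProvider = provider
                .to_string()
                .parse()
                .expect("Display output should parse back");
            assert_eq!(round_tripped, provider);
        }
    }
}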