1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod parse_output;
11mod paths;
12mod predict;
13mod progress;
14mod pull_examples;
15mod qa;
16mod reorder_patch;
17mod retrieve_context;
18mod score;
19mod split_commit;
20mod split_dataset;
21mod synthesize;
22mod word_diff;
23use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
24use collections::HashSet;
25use edit_prediction::EditPredictionStore;
26use futures::channel::mpsc;
27use futures::{SinkExt as _, StreamExt as _};
28use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
29use zeta_prompt::ZetaVersion;
30
31use reqwest_client::ReqwestClient;
32use serde::{Deserialize, Deserializer, Serialize, Serializer};
33use std::fmt::Display;
34use std::fs::{File, OpenOptions};
35use std::hash::{Hash, Hasher};
36use std::io::{BufRead, BufReader, BufWriter, Write};
37use std::sync::Mutex;
use std::{
    path::{Path, PathBuf},
    sync::Arc,
};
39
40use crate::distill::run_distill;
41use crate::example::{Example, group_examples_by_repo, read_example_files};
42use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
43use crate::format_prompt::run_format_prompt;
44use crate::load_project::run_load_project;
45use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
46use crate::predict::run_prediction;
47use crate::progress::Progress;
48use crate::retrieve_context::run_context_retrieval;
49use crate::score::run_scoring;
50use crate::split_commit::SplitCommitArgs;
51use crate::split_dataset::SplitArgs;
52use crate::synthesize::{SynthesizeConfig, run_synthesize};
53
// Top-level CLI arguments for `ep`. Most flags are `global = true` so they can
// appear before or after the subcommand. Reviewer notes below use `//` (not
// `///`) deliberately: clap turns doc comments on fields into --help text.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment (via util::shell_env) and exit immediately.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of worker tasks that process repository groups concurrently
    // (see the task loop in `main`).
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Process at most this many examples; also caps Snowflake fetches
    // (see `load_examples`).
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Drop this many examples from the front of the sorted/filtered list.
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or `captured-after:`/`rejected-after:`
    // specifiers (full syntax documented in INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output path; when absent, most commands stream JSONL to stdout.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (writes a temp file, then renames;
    // see `main` and `EpArgs::output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort on the first failed example instead of continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
}
86
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Variant doc comments double as --help text via ValueEnum; keep them short.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    // NOTE(review): unlike the enum-level doc above suggests, this variant also
    // suppresses the failed/ directory artifacts — see `handle_error`.
    /// Skip writing files
    SkipNoFiles,
}
99
/// Long help text attached to the positional `inputs` argument of [`EpArgs`].
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

    path
        Path to an example(s) file (.md, .json, or .jsonl)

    captured-after:{timestamp}
        Fetch captured examples from Snowflake after the given RFC3339 timestamp.
        These are examples captured via the "Capture Edit Prediction Example" action.

    rejected-after:{timestamp}
        Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
        These are predictions that were shown to users but rejected (useful for DPO training).

    Required environment variables to connect to Snowflake:
        EP_SNOWFLAKE_API_KEY
        EP_SNOWFLAKE_BASE_URL

    Optional:
        EP_SNOWFLAKE_ROLE

Examples:

    # Read examples from a file
    ep read examples.jsonl -o output.jsonl

    # Read captured examples after a timestamp
    ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

    # Read rejected predictions for DPO training
    ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

    # Mix multiple input sources
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
135
// Subcommands of `ep`. The `Display` impl below renders the canonical CLI
// invocation for each variant (used in failure "Re-run" hints) — keep the two
// in sync when adding variants. Doc comments here become --help text.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
}
173
174impl Display for Command {
175 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
176 match self {
177 Command::Read => write!(f, "read"),
178 Command::LoadProject => write!(f, "load-project"),
179 Command::Context => write!(f, "context"),
180 Command::FormatPrompt(args) => {
181 write!(f, "format-prompt --provider={}", args.provider)
182 }
183 Command::Predict(args) => match &args.provider {
184 Some(provider) => write!(f, "predict --provider={}", provider),
185 None => write!(f, "predict"),
186 },
187 Command::ParseOutput => write!(f, "parse-output"),
188 Command::Score(args) => match &args.provider {
189 Some(provider) => write!(f, "score --provider={}", provider),
190 None => write!(f, "score"),
191 },
192 Command::Distill => write!(f, "distill"),
193 Command::Eval(args) => match &args.predict.provider {
194 Some(provider) => write!(f, "eval --provider={}", provider),
195 None => write!(f, "eval"),
196 },
197 Command::Synthesize(args) => {
198 write!(f, "synthesize --repos {}", args.repos.join(" "))
199 }
200 Command::Clean => write!(f, "clean"),
201 Command::SplitCommit(_) => write!(f, "split-commit"),
202 Command::Split(_) => write!(f, "split"),
203 Command::FilterLanguages(_) => write!(f, "filter-languages"),
204 Command::ImportBatch(args) => {
205 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
206 }
207 Command::Qa(_) => {
208 write!(f, "qa")
209 }
210 }
211 }
212}
213
// Arguments for the `format-prompt` subcommand. (`//` comments on purpose:
// doc comments on clap fields would alter --help output.)
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to generate; defaults to
    // `PredictionProvider::default()` (zeta2 at the default ZetaVersion).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
219
// Arguments shared by `predict`, `score`, and (via EvalArgs) `eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Prediction provider; optional. NOTE(review): the semantics of `None`
    // are decided downstream in `predict`/`score` — confirm there before
    // documenting further.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // How many prediction runs to perform per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
227
// Arguments for the `eval` subcommand: prediction settings plus an optional
// JSON summary destination (written by `score::write_summary_json`).
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
236
/// Supported edit-prediction backends.
///
/// Versioned variants carry a [`ZetaVersion`] parsed from the `name:version`
/// CLI syntax; see the `FromStr`/`Display` impls below for the exact grammar.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2(ZetaVersion),
    Teacher(ZetaVersion),
    TeacherNonBatching(ZetaVersion),
}
246
247impl Default for PredictionProvider {
248 fn default() -> Self {
249 PredictionProvider::Zeta2(ZetaVersion::default())
250 }
251}
252
253impl std::fmt::Display for PredictionProvider {
254 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255 match self {
256 PredictionProvider::Sweep => write!(f, "sweep"),
257 PredictionProvider::Mercury => write!(f, "mercury"),
258 PredictionProvider::Zeta1 => write!(f, "zeta1"),
259 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
260 PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
261 PredictionProvider::TeacherNonBatching(version) => {
262 write!(f, "teacher-non-batching:{version}")
263 }
264 }
265 }
266}
267
268impl std::str::FromStr for PredictionProvider {
269 type Err = anyhow::Error;
270
271 fn from_str(mut s: &str) -> Result<Self, Self::Err> {
272 let mut version = ZetaVersion::default();
273 if let Some((first, second)) = s.split_once(':') {
274 version = ZetaVersion::parse(second)?;
275 s = first;
276 }
277
278 let s_lower = s.to_lowercase();
279 match s_lower.as_str() {
280 "sweep" => Ok(PredictionProvider::Sweep),
281 "mercury" => Ok(PredictionProvider::Mercury),
282 "zeta1" => Ok(PredictionProvider::Zeta1),
283 "zeta2" => Ok(PredictionProvider::Zeta2(version)),
284 "teacher" => Ok(PredictionProvider::Teacher(version)),
285 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
286 Ok(PredictionProvider::TeacherNonBatching(version))
287 }
288 _ => {
289 anyhow::bail!(
290 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
291 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
292 Available zeta versions:\n{}",
293 ZetaVersion::options_as_string()
294 )
295 }
296 }
297 }
298}
299
300impl Serialize for PredictionProvider {
301 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
302 where
303 S: Serializer,
304 {
305 serializer.serialize_str(&self.to_string())
306 }
307}
308
309impl<'de> Deserialize<'de> for PredictionProvider {
310 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
311 where
312 D: Deserializer<'de>,
313 {
314 let s = String::deserialize(deserializer)?;
315 s.parse().map_err(serde::de::Error::custom)
316 }
317}
318
// Arguments for the `synthesize` subcommand (see `run_synthesize`).
// Field doc comments are user-facing --help text; do not edit casually.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
337
// Arguments for the `import-batch` subcommand (see the ImportBatch arm in
// `main`, which forwards these IDs to `AnthropicClient::import_batches`).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Anthropic batch IDs to import (e.g., msgbatch_xxx)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
}
344
345impl EpArgs {
346 fn output_path(&self) -> Option<PathBuf> {
347 if self.in_place {
348 if self.inputs.len() == 1 {
349 self.inputs.first().cloned()
350 } else {
351 panic!("--in-place requires exactly one input file")
352 }
353 } else {
354 self.output.clone()
355 }
356 }
357}
358
/// Gathers the examples to process from the CLI inputs.
///
/// Plain paths are read as example files; `captured-after:` /
/// `rejected-after:` specifiers are fetched from Snowflake. The merged list is
/// then sorted by repo/rev, filtered by `--name`/`--repo`, pruned of examples
/// already present in the output file (resume support), and finally windowed
/// by `--offset`/`--limit`.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    // Partition inputs into Snowflake specifiers vs. plain file paths.
    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Provisional total so progress output is meaningful during the Snowflake
    // fetches; the final total is set again at the end of this function.
    Progress::global().set_total_examples(examples.len());

    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    // If the file inputs already satisfy --limit, skip the network fetches.
    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // Per-timestamp row cap: the remaining --limit budget, or 5000.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                background_executor.clone(),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client,
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                background_executor,
            )
            .await?;
            examples.append(&mut rejected_examples);
        }
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path {
            resume_from_output(path, &mut examples);
        }
    }

    // NOTE(review): `splice` panics if --offset exceeds the number of
    // remaining examples — confirm that is acceptable for a CLI tool.
    if let Some(offset) = args.offset {
        examples.splice(0..offset, []);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
454
/// Stable fingerprint of an example spec, used by `resume_from_output` to
/// match in-memory examples against previously written output lines.
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
460
461fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
462 let file = match File::open(path) {
463 Ok(f) => f,
464 Err(_) => return,
465 };
466
467 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
468
469 let reader = BufReader::new(file);
470 let mut kept_lines = Vec::new();
471 let mut kept_hashes = HashSet::default();
472
473 for line in reader.lines() {
474 let line = match line {
475 Ok(l) => l,
476 Err(_) => continue,
477 };
478
479 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
480 let hash = spec_hash(&output_example.spec);
481 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
482 kept_hashes.insert(hash);
483 kept_lines.push(line);
484 }
485 }
486 }
487
488 let total = examples.len();
489 let already_processed = kept_hashes.len();
490
491 eprintln!(
492 "Resuming: {}/{} examples already processed",
493 already_processed, total
494 );
495
496 let file = OpenOptions::new()
497 .write(true)
498 .truncate(true)
499 .open(path)
500 .expect("Failed to open output file for rewriting");
501 let mut writer = BufWriter::new(file);
502 for line in &kept_lines {
503 writeln!(writer, "{}", line).expect("Failed to write to output file");
504 }
505 writer.flush().expect("Failed to flush output file");
506
507 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
508}
509
/// Entry point. Self-contained subcommands are dispatched synchronously and
/// return early; the remaining commands run the shared per-example pipeline
/// inside a headless GPUI application.
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    // No subcommand: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that do not need the headless app or the example pipeline.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client");
                if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                    eprintln!("Error importing batches: {:?}", e);
                    std::process::exit(1);
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Deletes all checked-out repos/worktrees under the data dir.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Qa(qa_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

    // Remaining commands (Read / LoadProject / Context / FormatPrompt /
    // Predict / ParseOutput / Score / Distill / Eval) share the pipeline below.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any pending provider batches before starting work.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // With exactly one example, fail loudly even without --failfast.
                let failfast_on_single_example = examples.len() == 1;

                // For --in-place, write to a temp file and rename at the end to avoid data loss on interruption
                let in_place_temp_path = if args.in_place {
                    output.as_ref().map(|path| {
                        let mut temp_path = path.clone();
                        temp_path.set_extension("jsonl.tmp");
                        temp_path
                    })
                } else {
                    None
                };

                // Single background writer task; workers send finished examples
                // as JSONL lines over this channel. Eval without --output gets
                // no writer (it only prints a report at the end).
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                        write_path.map(|path| {
                            let file = if args.in_place {
                                // For --in-place, write to temp file (truncate if exists)
                                OpenOptions::new()
                                    .create(true)
                                    .write(true)
                                    .truncate(true)
                                    .open(path)
                                    .expect("Failed to open temp output file")
                            } else {
                                // For regular output, append to support resuming
                                OpenOptions::new()
                                    .create(true)
                                    .append(true)
                                    .open(path)
                                    .expect("Failed to open output file")
                            };
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Work queue grouped by repository, so each worker processes a
                // whole repo's examples before tearing its project down.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the per-example step for the active
                                // command.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        // These were all dispatched (with an
                                        // early return) before the app started.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::Qa(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No output file: stream JSONL to stdout.
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Per-repo teardown: shut down language servers,
                            // deregister the project from the prediction store,
                            // and drop cached project state.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Flush any provider batches that accumulated during the run.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let (Some(temp_path), Some(final_path)) = (&in_place_temp_path, &output) {
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
910
911async fn handle_error(
912 error: anyhow::Error,
913 args: &EpArgs,
914 command: &Command,
915 app_state: &Arc<headless::EpAppState>,
916 failfast_on_single_example: bool,
917 example: &Example,
918) {
919 Progress::global().increment_failed();
920
921 let msg;
922 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
923 let example_name = example.spec.filename();
924
925 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
926 app_state
927 .fs
928 .write(
929 &failed_example_path,
930 &serde_json::to_vec_pretty(&example).unwrap(),
931 )
932 .await
933 .unwrap();
934 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
935 app_state
936 .fs
937 .write(&err_path, format!("{error:?}").as_bytes())
938 .await
939 .unwrap();
940
941 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
942 let mut file = OpenOptions::new()
943 .create(true)
944 .append(true)
945 .open(&failed_jsonl_path)
946 .expect("Failed to open failed.jsonl");
947 writeln!(file, "{}", serde_json::to_string(example).unwrap())
948 .expect("Failed to write to failed.jsonl");
949
950 let cursor_path = example
951 .repo_name()
952 .unwrap()
953 .worktree_path()
954 .join(&example.spec.cursor_path);
955 msg = format!(
956 indoc::indoc! {"
957 While processing \"{}\":
958
959 \x1b[31m{:?}\x1b[0m
960
961 Example: \x1b[36m{}\x1b[0m
962 Error file: \x1b[36m{}\x1b[0m
963 Cursor file: \x1b[36m{}\x1b[0m
964 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
965 "},
966 example.spec.name,
967 error,
968 failed_example_path.display(),
969 err_path.display(),
970 cursor_path.display(),
971 command,
972 failed_example_path.display(),
973 );
974 } else {
975 msg = format!(
976 indoc::indoc! {"
977 While processing \"{}\":
978
979 \x1b[31m{:?}\x1b[0m
980 "},
981 example.spec.name, error
982 );
983 }
984
985 if args.failfast || failfast_on_single_example {
986 Progress::global().finalize();
987 panic!("{}", msg);
988 } else {
989 log::error!("{}", msg);
990 }
991}