1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod split_dataset;
18mod synthesize;
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};

use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor};
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use zeta_prompt::ZetaVersion;

use crate::distill::run_distill;
use crate::example::{Example, group_examples_by_repo, read_example_files};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
use crate::split_dataset::SplitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};
48
49#[derive(Parser, Debug)]
50#[command(name = "ep")]
51struct EpArgs {
52 #[arg(long, default_value_t = false)]
53 printenv: bool,
54 #[clap(long, default_value_t = 10, global = true)]
55 max_parallelism: usize,
56 #[clap(long, global = true)]
57 limit: Option<usize>,
58 /// Filter examples by name
59 #[clap(long, global = true)]
60 name: Option<String>,
61 /// Filter examples by repository
62 #[clap(long, global = true)]
63 repo: Option<String>,
64 #[command(subcommand)]
65 command: Option<Command>,
66 #[clap(global = true, help = INPUTS_HELP)]
67 inputs: Vec<PathBuf>,
68 #[arg(long, short, global = true)]
69 output: Option<PathBuf>,
70 #[arg(long, short, global = true)]
71 in_place: bool,
72 #[arg(long, short, global = true)]
73 failfast: bool,
74 /// How to handle failed examples in output: keep them or skip them.
75 /// Failed examples are always logged to the run's failed directory.
76 #[arg(long, global = true, default_value = "keep")]
77 failed: FailedHandling,
78}
79
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Parsed by clap as the value of the global `--failed` flag (default `keep`);
// consulted per example in the main processing loop before writing output.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
}
90
// Long-form, user-facing help text attached to the positional `inputs`
// argument of `EpArgs`. NOTE(review): keep the `captured-after:` syntax here
// in sync with `pull_examples::parse_captured_after_input`.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.

    You can specify this multiple times and mix it with file inputs.

    Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

    Optional:
    EP_SNOWFLAKE_ROLE

Examples:

    # Predict from a file
    ep predict examples.jsonl

    # Predict from captured examples after a timestamp
    ep predict captured-after:2025-01-01T00:00:00Z

    # Mix file inputs and captured-after in the same invocation
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
120
// All `ep` subcommands. The `///` comment on each variant doubles as the
// clap-generated help text; the `Display` impl below renders each variant as
// the CLI fragment used in re-run instructions for failed examples.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
}
149
150impl Display for Command {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 match self {
153 Command::ParseExample => write!(f, "parse-example"),
154 Command::LoadProject => write!(f, "load-project"),
155 Command::Context => write!(f, "context"),
156 Command::FormatPrompt(args) => {
157 write!(f, "format-prompt --provider={}", args.provider)
158 }
159 Command::Predict(args) => {
160 write!(f, "predict --provider={}", args.provider)
161 }
162 Command::Score(args) => {
163 write!(f, "score --provider={}", args.provider)
164 }
165 Command::Distill => write!(f, "distill"),
166 Command::Eval(args) => {
167 write!(f, "eval --provider={}", args.provider)
168 }
169 Command::Synthesize(args) => {
170 write!(f, "synthesize --repos {}", args.repos.join(" "))
171 }
172 Command::Clean => write!(f, "clean"),
173 Command::SplitCommit(_) => write!(f, "split-commit"),
174 Command::Split(_) => write!(f, "split"),
175 }
176 }
177}
178
// Arguments for the `format-prompt` subcommand: which provider's prompt
// format to emit. (`//` comments keep clap help output unchanged.)
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format is generated; defaults to zeta2.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
184
// Shared arguments for the `predict`, `score`, and `eval` subcommands.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Backend used to produce predictions; defaults to zeta2.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
    // How many times to run prediction per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
192
/// Backend used to produce edit predictions.
///
/// Zeta-based variants carry a [`ZetaVersion`] selecting the prompt format;
/// string forms like `zeta2:<version>` are handled by the `FromStr` impl
/// below, and `Display`/`Serialize` round-trip through that representation.
// NOTE(review): the two Teacher variants presumably differ in whether
// requests are batched (per their names) — confirm in the `predict` module.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2(ZetaVersion),
    Teacher(ZetaVersion),
    TeacherNonBatching(ZetaVersion),
}
202
203impl Default for PredictionProvider {
204 fn default() -> Self {
205 PredictionProvider::Zeta2(ZetaVersion::default())
206 }
207}
208
209impl std::fmt::Display for PredictionProvider {
210 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211 match self {
212 PredictionProvider::Sweep => write!(f, "sweep"),
213 PredictionProvider::Mercury => write!(f, "mercury"),
214 PredictionProvider::Zeta1 => write!(f, "zeta1"),
215 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
216 PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
217 PredictionProvider::TeacherNonBatching(version) => {
218 write!(f, "teacher-non-batching:{version}")
219 }
220 }
221 }
222}
223
224impl std::str::FromStr for PredictionProvider {
225 type Err = anyhow::Error;
226
227 fn from_str(mut s: &str) -> Result<Self, Self::Err> {
228 let mut version = ZetaVersion::default();
229 if let Some((first, second)) = s.split_once(':') {
230 version = ZetaVersion::parse(second)?;
231 s = first;
232 }
233
234 let s_lower = s.to_lowercase();
235 match s_lower.as_str() {
236 "sweep" => Ok(PredictionProvider::Sweep),
237 "mercury" => Ok(PredictionProvider::Mercury),
238 "zeta1" => Ok(PredictionProvider::Zeta1),
239 "zeta2" => Ok(PredictionProvider::Zeta2(version)),
240 "teacher" => Ok(PredictionProvider::Teacher(version)),
241 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
242 Ok(PredictionProvider::TeacherNonBatching(version))
243 }
244 _ => {
245 anyhow::bail!(
246 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
247 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
248 Available zeta versions:\n{}",
249 ZetaVersion::options_as_string()
250 )
251 }
252 }
253 }
254}
255
impl Serialize for PredictionProvider {
    /// Serializes the provider as its `Display` string (e.g. `"zeta2:<ver>"`),
    /// so the serialized form round-trips through the `FromStr`-based
    /// `Deserialize` impl below.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(&self.to_string())
    }
}
264
impl<'de> Deserialize<'de> for PredictionProvider {
    /// Deserializes from the string form produced by `Serialize`, delegating
    /// to the `FromStr` impl and surfacing parse failures as serde errors.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        s.parse().map_err(serde::de::Error::custom)
    }
}
274
// Arguments for the `synthesize` subcommand, converted into a
// `SynthesizeConfig` in `main` (which also supplies `--output` as the
// required output directory).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
293
294impl EpArgs {
295 fn output_path(&self) -> Option<PathBuf> {
296 if self.in_place {
297 if self.inputs.len() == 1 {
298 self.inputs.first().cloned()
299 } else {
300 panic!("--in-place requires exactly one input file")
301 }
302 } else {
303 self.output.clone()
304 }
305 }
306}
307
/// Builds the example set for a run: reads file inputs, fetches Snowflake
/// `captured-after:` inputs, applies the global `--name`/`--repo`/`--limit`
/// filters, and drops examples already present in `output_path` (resume).
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Partition inputs into `captured-after:<ts>` specifiers and plain files.
    let mut captured_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Provisional total for progress display; updated again at the end once
    // fetching, filters, limit, and resume have all been applied.
    Progress::global().set_total_examples(examples.len());

    // Only fetch from Snowflake whatever budget --limit leaves after the
    // file-based examples (None means no limit).
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !captured_after_timestamps.is_empty() {
        captured_after_timestamps.sort();

        // Without an explicit limit, cap each timestamp's fetch at 5000 rows.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        let mut captured_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &captured_after_timestamps,
            max_rows_per_timestamp,
            background_executor,
        )
        .await?;
        examples.append(&mut captured_examples);
    }

    // Deterministic order also groups examples sharing a repo/revision.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from the global --name / --repo flags.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // --limit applies after filtering, across file and fetched examples.
    if let Some(limit) = args.limit {
        if examples.len() > limit {
            examples.truncate(limit);
        }
    }

    // Drop examples a previous interrupted run already wrote to the output.
    if let Some(path) = output_path {
        resume_from_output(path, &mut examples);
    }

    Progress::global().set_total_examples(examples.len());

    Ok(examples)
}
375
/// Deterministic hash of an example spec (FxHasher is unseeded), used to
/// match examples between the input set and a previously-written output file
/// when resuming a run.
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
381
382fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
383 let file = match File::open(path) {
384 Ok(f) => f,
385 Err(_) => return,
386 };
387
388 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
389
390 let reader = BufReader::new(file);
391 let mut kept_lines = Vec::new();
392 let mut kept_hashes = HashSet::default();
393
394 for line in reader.lines() {
395 let line = match line {
396 Ok(l) => l,
397 Err(_) => continue,
398 };
399
400 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
401 let hash = spec_hash(&output_example.spec);
402 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
403 kept_hashes.insert(hash);
404 kept_lines.push(line);
405 }
406 }
407 }
408
409 let total = examples.len();
410 let already_processed = kept_hashes.len();
411
412 eprintln!(
413 "Resuming: {}/{} examples already processed",
414 already_processed, total
415 );
416
417 let file = OpenOptions::new()
418 .write(true)
419 .truncate(true)
420 .open(path)
421 .expect("Failed to open output file for rewriting");
422 let mut writer = BufWriter::new(file);
423 for line in &kept_lines {
424 writeln!(writer, "{}", line).expect("Failed to write to output file");
425 }
426 writer.flush().expect("Failed to flush output file");
427
428 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
429}
430
/// CLI entry point: parses arguments, dispatches standalone subcommands, and
/// runs the example-processing pipeline inside a headless GPUI application.
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        // Print the captured shell environment and exit.
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            // No subcommand: show help instead of doing work.
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that don't need the headless app or the example pipeline
    // are handled here and return early.
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Remaining commands run inside a headless GPUI application.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Prediction-based commands sync pending batch requests
                // before processing begins.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // A single-example run fails fast even without --failfast so
                // the error is surfaced immediately.
                let failfast_on_single_example = examples.len() == 1;

                // When writing to a file, a background task owns the writer
                // and receives serialized examples over a channel.
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        output.as_ref().map(|path| {
                            let file = OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file");
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Work is grouped per repository so each repo's project state
                // is only held by one worker task at a time.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            // Claim the next repository's batch of examples.
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                let result = async {
                                    match &command {
                                        Command::ParseExample => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) | Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        // These were dispatched before the
                                        // app started (see the early match).
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are still written unless
                                // --failed=skip; `eval` without --output
                                // prints no per-example lines.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Release per-repo resources before moving on to
                            // the next repository's batch.
                            if let Some(state) =
                                repo_examples.first().and_then(|e| e.state.as_ref())
                            {
                                let mut cx = cx.clone();
                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    let project = state.project.clone();
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state
                                .project_cache
                                .remove(&repo_examples.first().unwrap().spec.repository_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Flush batches submitted during this run.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                match &command {
                    Command::Eval(_) => score::print_report(&finished_examples.lock().unwrap()),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
712
/// Records a failed example and either aborts the run (with `--failfast`, or
/// when the run has a single example) or logs the error and continues.
///
/// Regardless of the `--failed` setting, three artifacts are always written:
/// a pretty-printed copy of the example and a `{name}_err.txt` file in
/// `FAILED_EXAMPLES_DIR`, plus one JSONL line appended to the run's
/// `failed.jsonl`.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    let example_name = example.spec.filename();
    // Persist the failing example itself so it can be re-run in isolation.
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    // Persist the debug-formatted error next to it.
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    // Append to the run-wide failure log (one JSON object per line).
    let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&failed_jsonl_path)
        .expect("Failed to open failed.jsonl");
    writeln!(file, "{}", serde_json::to_string(example).unwrap())
        .expect("Failed to write to failed.jsonl");

    let cursor_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    // ANSI-colored summary pointing at the artifacts and a re-run command
    // (the `{}` for `command` uses the Display impl above).
    let msg = format!(
        indoc::indoc! {"
        While processing \"{}\":

        \x1b[31m{:?}\x1b[0m

        Example: \x1b[36m{}\x1b[0m
        Error file: \x1b[36m{}\x1b[0m
        Cursor file: \x1b[36m{}\x1b[0m
        Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        failed_example_path.display(),
        err_path.display(),
        cursor_path.display(),
        command,
        failed_example_path.display(),
    );
    if args.failfast || failfast_on_single_example {
        // Finalize progress output before panicking so the bar is not left
        // mid-render.
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}