1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod split_dataset;
18mod synthesize;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor};
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use zeta_prompt::ZetaVersion;
35
36use crate::distill::run_distill;
37use crate::example::{Example, group_examples_by_repo, read_example_files};
38use crate::format_prompt::run_format_prompt;
39use crate::load_project::run_load_project;
40use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
41use crate::predict::run_prediction;
42use crate::progress::Progress;
43use crate::retrieve_context::run_context_retrieval;
44use crate::score::run_scoring;
45use crate::split_commit::SplitCommitArgs;
46use crate::split_dataset::SplitArgs;
47use crate::synthesize::{SynthesizeConfig, run_synthesize};
48
49#[derive(Parser, Debug)]
50#[command(name = "ep")]
51struct EpArgs {
52 #[arg(long, default_value_t = false)]
53 printenv: bool,
54 #[clap(long, default_value_t = 10, global = true)]
55 max_parallelism: usize,
56 #[clap(long, global = true)]
57 limit: Option<usize>,
58 /// Filter examples by name
59 #[clap(long, global = true)]
60 name: Option<String>,
61 /// Filter examples by repository
62 #[clap(long, global = true)]
63 repo: Option<String>,
64 #[command(subcommand)]
65 command: Option<Command>,
66 #[clap(global = true, help = INPUTS_HELP)]
67 inputs: Vec<PathBuf>,
68 #[arg(long, short, global = true)]
69 output: Option<PathBuf>,
70 #[arg(long, short, global = true)]
71 in_place: bool,
72 #[arg(long, short, global = true)]
73 failfast: bool,
74 /// How to handle failed examples in output: keep them or skip them.
75 /// Failed examples are always logged to the run's failed directory.
76 #[arg(long, global = true, default_value = "keep")]
77 failed: FailedHandling,
78}
79
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// The `ValueEnum` derive exposes the variants on the CLI as `keep` / `skip`
// (see `--failed`'s `default_value = "keep"` in `EpArgs`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
}
90
// Long help text attached to the positional `inputs` argument of `EpArgs`.
// Documents both plain file inputs and the `captured-after:` specifier parsed
// by `pull_examples::parse_captured_after_input` in `load_examples`.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.

    You can specify this multiple times and mix it with file inputs.

    Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

    Optional:
    EP_SNOWFLAKE_ROLE

Examples:

    # Predict from a file
    ep predict examples.jsonl

    # Predict from captured examples after a timestamp
    ep predict captured-after:2025-01-01T00:00:00Z

    # Mix file inputs and captured-after in the same invocation
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
120
// Subcommands of the `ep` CLI. The `///` comments below double as clap help
// text, so they are kept exactly as shown to users. `Clean`, `Synthesize`,
// `SplitCommit`, and `Split` are handled before the headless app starts; the
// rest run through the per-example worker loop in `main`.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
}
149
150impl Display for Command {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 match self {
153 Command::ParseExample => write!(f, "parse-example"),
154 Command::LoadProject => write!(f, "load-project"),
155 Command::Context => write!(f, "context"),
156 Command::FormatPrompt(args) => {
157 write!(f, "format-prompt --provider={}", args.provider)
158 }
159 Command::Predict(args) => {
160 write!(f, "predict --provider={}", args.provider)
161 }
162 Command::Score(args) => {
163 write!(f, "score --provider={}", args.provider)
164 }
165 Command::Distill => write!(f, "distill"),
166 Command::Eval(args) => {
167 write!(f, "eval --provider={}", args.provider)
168 }
169 Command::Synthesize(args) => {
170 write!(f, "synthesize --repos {}", args.repos.join(" "))
171 }
172 Command::Clean => write!(f, "clean"),
173 Command::SplitCommit(_) => write!(f, "split-commit"),
174 Command::Split(_) => write!(f, "split"),
175 }
176 }
177}
178
179#[derive(Debug, Args, Clone)]
180struct FormatPromptArgs {
181 #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
182 provider: PredictionProvider,
183}
184
185#[derive(Debug, Args, Clone)]
186struct PredictArgs {
187 #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
188 provider: PredictionProvider,
189 #[clap(long, default_value_t = 1)]
190 repetitions: usize,
191}
192
/// Backend used to produce edit predictions.
///
/// Variants carrying a [`ZetaVersion`] take it from an optional `name:version`
/// CLI suffix (see the `FromStr` impl); for the other variants any suffix is
/// validated and then discarded.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2(ZetaVersion),
    Teacher(ZetaVersion),
    // NOTE(review): presumably the same model as `Teacher` but without request
    // batching — confirm against the predict module.
    TeacherNonBatching(ZetaVersion),
}
202
203impl Default for PredictionProvider {
204 fn default() -> Self {
205 PredictionProvider::Zeta2(ZetaVersion::default())
206 }
207}
208
209impl std::fmt::Display for PredictionProvider {
210 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211 match self {
212 PredictionProvider::Sweep => write!(f, "sweep"),
213 PredictionProvider::Mercury => write!(f, "mercury"),
214 PredictionProvider::Zeta1 => write!(f, "zeta1"),
215 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
216 PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
217 PredictionProvider::TeacherNonBatching(version) => {
218 write!(f, "teacher-non-batching:{version}")
219 }
220 }
221 }
222}
223
224impl std::str::FromStr for PredictionProvider {
225 type Err = anyhow::Error;
226
227 fn from_str(mut s: &str) -> Result<Self, Self::Err> {
228 let mut version = ZetaVersion::default();
229 if let Some((first, second)) = s.split_once(':') {
230 version = ZetaVersion::parse(second)?;
231 s = first;
232 }
233
234 let s_lower = s.to_lowercase();
235 match s_lower.as_str() {
236 "sweep" => Ok(PredictionProvider::Sweep),
237 "mercury" => Ok(PredictionProvider::Mercury),
238 "zeta1" => Ok(PredictionProvider::Zeta1),
239 "zeta2" => Ok(PredictionProvider::Zeta2(version)),
240 "teacher" => Ok(PredictionProvider::Teacher(version)),
241 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
242 Ok(PredictionProvider::TeacherNonBatching(version))
243 }
244 _ => {
245 anyhow::bail!(
246 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
247 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
248 Available zeta versions:\n{}",
249 ZetaVersion::options_as_string()
250 )
251 }
252 }
253 }
254}
255
256impl Serialize for PredictionProvider {
257 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
258 where
259 S: Serializer,
260 {
261 serializer.serialize_str(&self.to_string())
262 }
263}
264
265impl<'de> Deserialize<'de> for PredictionProvider {
266 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
267 where
268 D: Deserializer<'de>,
269 {
270 let s = String::deserialize(deserializer)?;
271 s.parse().map_err(serde::de::Error::custom)
272 }
273}
274
275#[derive(Debug, Args, Clone)]
276struct SynthesizeArgs {
277 /// Repository URLs (git@github.com:owner/repo or https://...)
278 #[clap(long, required = true, num_args = 1..)]
279 repos: Vec<String>,
280
281 /// Number of examples to generate per repository
282 #[clap(long, default_value_t = 5)]
283 count: usize,
284
285 /// Maximum commits to scan per repository before giving up
286 #[clap(long, default_value_t = 100)]
287 max_commits: usize,
288
289 /// Ignore state file and reprocess all commits
290 #[clap(long)]
291 fresh: bool,
292}
293
294impl EpArgs {
295 fn output_path(&self) -> Option<PathBuf> {
296 if self.in_place {
297 if self.inputs.len() == 1 {
298 self.inputs.first().cloned()
299 } else {
300 panic!("--in-place requires exactly one input file")
301 }
302 } else {
303 self.output.clone()
304 }
305 }
306}
307
/// Loads the examples for this run from file inputs and/or Snowflake
/// `captured-after:` inputs, then applies the global `--name`/`--repo`
/// filters, the `--limit` cap, and resume-from-output deduplication.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Split raw inputs into `captured-after:` timestamps and plain file paths.
    let mut captured_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    Progress::global().set_total_examples(examples.len());

    // How many more examples Snowflake may contribute before hitting --limit.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !captured_after_timestamps.is_empty() {
        captured_after_timestamps.sort();

        // Without an explicit --limit, cap each timestamp query at 5000 rows.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        let mut captured_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &captured_after_timestamps,
            max_rows_per_timestamp,
            background_executor,
        )
        .await?;
        examples.append(&mut captured_examples);
    }

    // Order examples so that all entries for one repo/rev are adjacent; the
    // downstream grouping by repo in `main` benefits from this ordering.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from the global --name / --repo flags.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    if let Some(limit) = args.limit {
        if examples.len() > limit {
            examples.truncate(limit);
        }
    }

    // When resuming into an existing output file, drop examples that were
    // already processed there (and prune stale lines from that file).
    if let Some(path) = output_path {
        resume_from_output(path, &mut examples);
    }

    // Re-publish the total now that filters/limit/resume have been applied.
    Progress::global().set_total_examples(examples.len());

    Ok(examples)
}
375
/// Stable-within-a-build hash of an example spec, used to match examples
/// between the input set and a previous run's output (see `resume_from_output`).
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
381
382fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
383 let file = match File::open(path) {
384 Ok(f) => f,
385 Err(_) => return,
386 };
387
388 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
389
390 let reader = BufReader::new(file);
391 let mut kept_lines = Vec::new();
392 let mut kept_hashes = HashSet::default();
393
394 for line in reader.lines() {
395 let line = match line {
396 Ok(l) => l,
397 Err(_) => continue,
398 };
399
400 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
401 let hash = spec_hash(&output_example.spec);
402 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
403 kept_hashes.insert(hash);
404 kept_lines.push(line);
405 }
406 }
407 }
408
409 let total = examples.len();
410 let already_processed = kept_hashes.len();
411
412 eprintln!(
413 "Resuming: {}/{} examples already processed",
414 already_processed, total
415 );
416
417 let file = OpenOptions::new()
418 .write(true)
419 .truncate(true)
420 .open(path)
421 .expect("Failed to open output file for rewriting");
422 let mut writer = BufWriter::new(file);
423 for line in &kept_lines {
424 writeln!(writer, "{}", line).expect("Failed to write to output file");
425 }
426 writer.flush().expect("Failed to flush output file");
427
428 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
429}
430
/// Entry point: parses CLI args, handles the "simple" subcommands directly,
/// and runs everything else through a headless gpui application with a pool of
/// worker tasks processing examples grouped by repository.
fn main() {
    let args = EpArgs::parse();

    // --printenv dumps the shell environment and exits; nothing else runs.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            // No subcommand given: show help and exit.
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that don't need the headless app or the example pipeline
    // are handled here and return early.
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) =
                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Everything below runs inside a headless gpui application.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync batch state for provider-backed commands before the run
                // (see `predict::sync_batches`).
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // With exactly one example, fail fast so the error surfaces.
                let failfast_on_single_example = examples.len() == 1;

                // When an output file is in use, a detached background task
                // owns the writer; workers send serialized lines to it over an
                // unbounded channel so file writes are serialized.
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        output.as_ref().map(|path| {
                            let file = OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file")
                            ;
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Shared work queue of per-repository example groups, plus the
                // accumulator for examples that finished processing.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                // Spawn `--max-parallelism` workers; each claims one whole
                // repo group at a time so per-repo project state is reused.
                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch this example to the selected subcommand.
                                let result = async {
                                    match &command {
                                        Command::ParseExample => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) | Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_) => {
                                            // These were handled before the app started.
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                // Record failures; `handle_error` panics when
                                // failfast applies, otherwise logs and returns.
                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples still go to the main output
                                // unless --failed=skip was requested.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No output file: stream results to stdout.
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Group done: drop its project from the prediction
                            // store and the project cache, then release each
                            // example's loaded state before collecting them.
                            if let Some(state) =
                                repo_examples.first().and_then(|e| e.state.as_ref())
                            {
                                let mut cx = cx.clone();
                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    let project = state.project.clone();
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state
                                .project_cache
                                .remove(&repo_examples.first().unwrap().spec.repository_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Sync batch state again after the run for provider-backed
                // commands.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // Eval prints its aggregate report over all finished examples.
                match &command {
                    Command::Eval(_) => score::print_report(&finished_examples.lock().unwrap()),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app once the pipeline completes.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
709
/// Records a failed example and decides whether to abort the run.
///
/// Always: bumps the failure counter, writes the example (pretty JSON) and its
/// error text into `FAILED_EXAMPLES_DIR`, and appends the example to the run's
/// `failed.jsonl`. Then panics with a summary when `--failfast` is set or the
/// run had a single example; otherwise logs the summary and continues.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    let example_name = example.spec.filename();
    // Pretty-printed copy of the failing example, usable as a re-run input.
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    // Companion file with the debug-formatted error.
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    // Run-wide failure log: one compact JSON line per failed example.
    let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&failed_jsonl_path)
        .expect("Failed to open failed.jsonl");
    writeln!(file, "{}", serde_json::to_string(example).unwrap())
        .expect("Failed to write to failed.jsonl");

    // NOTE(review): `repo_name().unwrap()` panics if the example lacks a repo
    // name — assumed always present by this point; confirm upstream.
    let cursor_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    // Human-readable summary with ANSI colors and a ready-to-paste re-run line.
    let msg = format!(
        indoc::indoc! {"
            While processing \"{}\":

            \x1b[31m{:?}\x1b[0m

            Example: \x1b[36m{}\x1b[0m
            Error file: \x1b[36m{}\x1b[0m
            Cursor file: \x1b[36m{}\x1b[0m
            Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        failed_example_path.display(),
        err_path.display(),
        cursor_path.display(),
        command,
        failed_example_path.display(),
    );
    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}