1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod split_dataset;
18mod synthesize;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application};
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;

use crate::distill::run_distill;
use crate::example::{Example, group_examples_by_repo, read_example_files};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
use crate::split_dataset::SplitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};
46
47#[derive(Parser, Debug)]
48#[command(name = "ep")]
49struct EpArgs {
50 #[arg(long, default_value_t = false)]
51 printenv: bool,
52 #[clap(long, default_value_t = 10, global = true)]
53 max_parallelism: usize,
54 #[clap(long, global = true)]
55 limit: Option<usize>,
56 /// Filter examples by name
57 #[clap(long, global = true)]
58 name: Option<String>,
59 /// Filter examples by repository
60 #[clap(long, global = true)]
61 repo: Option<String>,
62 #[command(subcommand)]
63 command: Option<Command>,
64 #[clap(global = true, help = INPUTS_HELP)]
65 inputs: Vec<PathBuf>,
66 #[arg(long, short, global = true)]
67 output: Option<PathBuf>,
68 #[arg(long, short, global = true)]
69 in_place: bool,
70 #[arg(long, short, global = true)]
71 failfast: bool,
72 /// How to handle failed examples in output: keep them or skip them.
73 /// Failed examples are always logged to the run's failed directory.
74 #[arg(long, global = true, default_value = "keep")]
75 failed: FailedHandling,
76}
77
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
//
// `ValueEnum` exposes the variants as kebab-case CLI values ("keep"/"skip"),
// matching `default_value = "keep"` on `EpArgs::failed`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
}
88
// Long help text attached to the positional `inputs` argument on `EpArgs`
// (via `help = INPUTS_HELP`). This string is user-visible CLI output; its
// contents are intentionally left exactly as shipped.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.

    You can specify this multiple times and mix it with file inputs.

    Required environment variables to connect to Snowflake:
        EP_SNOWFLAKE_API_KEY
        EP_SNOWFLAKE_BASE_URL

    Optional:
        EP_SNOWFLAKE_ROLE

Examples:

    # Predict from a file
    ep predict examples.jsonl

    # Predict from captured examples after a timestamp
    ep predict captured-after:2025-01-01T00:00:00Z

    # Mix file inputs and captured-after in the same invocation
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
118
// Subcommands of `ep`. The `///` docs on each variant become the CLI help
// text; the `Display` impl renders a variant back into an equivalent CLI
// invocation for "Re-run:" hints in failure output.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
}
147
148impl Display for Command {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 match self {
151 Command::ParseExample => write!(f, "parse-example"),
152 Command::LoadProject => write!(f, "load-project"),
153 Command::Context => write!(f, "context"),
154 Command::FormatPrompt(format_prompt_args) => write!(
155 f,
156 "format-prompt --prompt-format={}",
157 format_prompt_args
158 .prompt_format
159 .to_possible_value()
160 .unwrap()
161 .get_name()
162 ),
163 Command::Predict(predict_args) => {
164 write!(
165 f,
166 "predict --provider={:?}",
167 predict_args
168 .provider
169 .to_possible_value()
170 .unwrap()
171 .get_name()
172 )
173 }
174 Command::Score(predict_args) => {
175 write!(
176 f,
177 "score --provider={:?}",
178 predict_args
179 .provider
180 .to_possible_value()
181 .unwrap()
182 .get_name()
183 )
184 }
185 Command::Distill => write!(f, "distill"),
186 Command::Eval(predict_args) => write!(
187 f,
188 "eval --provider={:?}",
189 predict_args
190 .provider
191 .to_possible_value()
192 .unwrap()
193 .get_name()
194 ),
195 Command::Synthesize(args) => {
196 write!(f, "synthesize --repo={}", args.repo)
197 }
198 Command::Clean => write!(f, "clean"),
199 Command::SplitCommit(_) => write!(f, "split-commit"),
200 Command::Split(_) => write!(f, "split"),
201 }
202 }
203}
204
// Arguments for the `format-prompt` subcommand.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Which prompt layout to emit; see `PromptFormat`.
    #[clap(long, short('p'))]
    prompt_format: PromptFormat,
}
210
// Target prompt layout for `format-prompt`. Carries serde derives because
// the choice is written into example JSON as well as parsed from the CLI.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}
216
// Shared arguments for the `predict`, `score`, and `eval` subcommands.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Backend that performs the prediction.
    #[clap(long)]
    provider: PredictionProvider,
    // How many prediction runs to perform per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
224
// Prediction backends selectable via `--provider`.
// NOTE(review): each provider's semantics live in the `predict` module; this
// enum only defines the CLI values and their serialized form.
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}
234
// Arguments for the `synthesize` subcommand; converted into a
// `SynthesizeConfig` in `main` (which also supplies `--output` as the
// required output directory).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URL (git@github.com:owner/repo or https://...)
    #[clap(long)]
    repo: String,

    /// Number of examples to generate
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
253
254impl EpArgs {
255 fn output_path(&self) -> Option<PathBuf> {
256 if self.in_place {
257 if self.inputs.len() == 1 {
258 self.inputs.first().cloned()
259 } else {
260 panic!("--in-place requires exactly one input file")
261 }
262 } else {
263 self.output.clone()
264 }
265 }
266}
267
268async fn load_examples(
269 http_client: Arc<dyn http_client::HttpClient>,
270 args: &EpArgs,
271 output_path: Option<&PathBuf>,
272) -> anyhow::Result<Vec<Example>> {
273 let mut captured_after_timestamps = Vec::new();
274 let mut file_inputs = Vec::new();
275
276 for input in &args.inputs {
277 let input_string = input.to_string_lossy();
278 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
279 captured_after_timestamps.push(timestamp.to_string());
280 } else {
281 file_inputs.push(input.clone());
282 }
283 }
284
285 let mut examples = read_example_files(&file_inputs);
286
287 Progress::global().set_total_examples(examples.len());
288
289 let remaining_limit_for_snowflake =
290 args.limit.map(|limit| limit.saturating_sub(examples.len()));
291
292 if let Some(0) = remaining_limit_for_snowflake {
293 log::info!(
294 "skipping captured-after inputs because --limit is already satisfied by example files"
295 );
296 } else if !captured_after_timestamps.is_empty() {
297 captured_after_timestamps.sort();
298
299 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
300
301 let mut captured_examples = pull_examples::fetch_captured_examples_after(
302 http_client,
303 &captured_after_timestamps,
304 max_rows_per_timestamp,
305 )
306 .await?;
307 examples.append(&mut captured_examples);
308 }
309
310 crate::example::sort_examples_by_repo_and_rev(&mut examples);
311
312 if let Some(name_filter) = &args.name {
313 examples.retain(|example| example.spec.name.contains(name_filter));
314 }
315 if let Some(repo_filter) = &args.repo {
316 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
317 }
318
319 if let Some(limit) = args.limit {
320 if examples.len() > limit {
321 examples.truncate(limit);
322 }
323 }
324
325 if let Some(path) = output_path {
326 resume_from_output(path, &mut examples);
327 }
328
329 Progress::global().set_total_examples(examples.len());
330
331 Ok(examples)
332}
333
334fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
335 let mut hasher = collections::FxHasher::default();
336 spec.hash(&mut hasher);
337 hasher.finish()
338}
339
340fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
341 let file = match File::open(path) {
342 Ok(f) => f,
343 Err(_) => return,
344 };
345
346 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
347
348 let reader = BufReader::new(file);
349 let mut kept_lines = Vec::new();
350 let mut kept_hashes = HashSet::default();
351
352 for line in reader.lines() {
353 let line = match line {
354 Ok(l) => l,
355 Err(_) => continue,
356 };
357
358 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
359 let hash = spec_hash(&output_example.spec);
360 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
361 kept_hashes.insert(hash);
362 kept_lines.push(line);
363 }
364 }
365 }
366
367 let total = examples.len();
368 let already_processed = kept_hashes.len();
369
370 eprintln!(
371 "Resuming: {}/{} examples already processed",
372 already_processed, total
373 );
374
375 let file = OpenOptions::new()
376 .write(true)
377 .truncate(true)
378 .open(path)
379 .expect("Failed to open output file for rewriting");
380 let mut writer = BufWriter::new(file);
381 for line in &kept_lines {
382 writeln!(writer, "{}", line).expect("Failed to write to output file");
383 }
384 writer.flush().expect("Failed to flush output file");
385
386 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
387}
388
fn main() {
    let args = EpArgs::parse();

    // `--printenv` short-circuits everything: dump the shell environment and
    // exit (used by the environment-capture mechanism).
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    // Resolve the output destination up front; panics if `--in-place` is
    // combined with anything other than exactly one input.
    let output = args.output_path();
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            // No subcommand: print usage and exit.
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // These commands need neither examples nor the headless app; handle them
    // synchronously and return before spinning up gpui.
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_url: synth_args.repo.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) =
                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Remaining commands run inside a headless gpui application so they can
    // use the project/worktree infrastructure.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let mut examples =
                    load_examples(app_state.client.http_client(), &args, output.as_ref()).await?;

                // Prediction-like commands first reconcile any outstanding
                // provider-side batches from a previous run.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // With a single example, fail fast even without --failfast so
                // the error surfaces as a full panic message.
                let failfast_on_single_example = examples.len() == 1;

                // Results are sent over a channel to a background task that
                // appends JSONL lines to the output file. `eval` without an
                // explicit --output gets no writer (report only).
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        output.as_ref().map(|path| {
                            let file = OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file");
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    // Flush per line so completed work survives
                                    // a crash and resuming sees whole lines.
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Examples grouped by repository run sequentially within a
                // group (they share a worktree); groups run concurrently, at
                // most `--max-parallelism` groups at a time.
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            // Dispatch the selected command for this example.
                            let result = async {
                                match &command {
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args.prompt_format,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            Some(args.provider),
                                            args.repetitions,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    // Handled synchronously before the app started.
                                    Command::Clean
                                    | Command::Synthesize(_)
                                    | Command::SplitCommit(_)
                                    | Command::Split(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            // Record failures; `handle_error` panics when a
                            // failfast condition applies, otherwise logs.
                            let failed = if let Err(error) = result {
                                handle_error(
                                    error,
                                    &args,
                                    &command,
                                    &app_state,
                                    failfast_on_single_example,
                                    example,
                                )
                                .await;
                                true
                            } else {
                                false
                            };

                            // Failed examples still go to the main output
                            // unless --failed=skip was requested.
                            let should_write = !failed || args.failed == FailedHandling::Keep;
                            if should_write {
                                if let Some(ref mut sender) = output_sender.clone() {
                                    let line = serde_json::to_string(example).unwrap();
                                    sender
                                        .send(line)
                                        .await
                                        .expect("Failed to send to output writer");
                                } else if args.output.is_none()
                                    && !matches!(command, Command::Eval(_))
                                {
                                    // No output file configured: stream JSONL
                                    // to stdout instead.
                                    let line = serde_json::to_string(example).unwrap();
                                    println!("{}", line);
                                }
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }

                Progress::global().finalize();

                // Flush any provider batches accumulated during this run.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // `eval` prints its aggregate report at the end.
                match &command {
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app once processing completes.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
613
/// Records a failed example and either aborts the process (failfast) or logs
/// the failure and lets the run continue.
///
/// Side effects:
/// - increments the global failure counter
/// - writes `<name>.json` (the example) and `<name>_err.txt` (the error)
///   into the run's failed-examples directory
/// - appends the example as a JSONL line to `failed.jsonl` in the run dir
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    let example_name = example.spec.filename();
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    // Append to the cumulative failure log for this run (synchronous std I/O,
    // unlike the async fs writes above).
    let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&failed_jsonl_path)
        .expect("Failed to open failed.jsonl");
    writeln!(file, "{}", serde_json::to_string(example).unwrap())
        .expect("Failed to write to failed.jsonl");

    // NOTE(review): this `unwrap()` can panic while we are already reporting
    // a failure (masking the original error) if the example's repo name
    // cannot be resolved — consider a fallback path. TODO confirm
    // `repo_name()`'s failure modes before changing.
    let cursor_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    let msg = format!(
        indoc::indoc! {"
            While processing \"{}\":

            \x1b[31m{:?}\x1b[0m

            Example: \x1b[36m{}\x1b[0m
            Error file: \x1b[36m{}\x1b[0m
            Cursor file: \x1b[36m{}\x1b[0m
            Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        failed_example_path.display(),
        err_path.display(),
        cursor_path.display(),
        command,
        failed_example_path.display(),
    );
    // Failfast (explicit flag, or implicit single-example run) aborts with a
    // panic; otherwise the failure is logged and the run continues.
    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}