main.rs

  1mod anthropic_client;
  2mod distill;
  3mod example;
  4mod format_prompt;
  5mod git;
  6mod headless;
  7mod load_project;
  8mod metrics;
  9mod paths;
 10mod predict;
 11mod progress;
 12mod pull_examples;
 13mod reorder_patch;
 14mod retrieve_context;
 15mod score;
 16mod split_commit;
 17mod split_dataset;
 18mod synthesize;
 19use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
 20use collections::HashSet;
 21use edit_prediction::EditPredictionStore;
 22use futures::channel::mpsc;
 23use futures::{SinkExt as _, StreamExt as _};
 24use gpui::{AppContext as _, Application};
 25
 26use reqwest_client::ReqwestClient;
 27use serde::{Deserialize, Serialize};
 28use std::fmt::Display;
 29use std::fs::{File, OpenOptions};
 30use std::hash::{Hash, Hasher};
 31use std::io::{BufRead, BufReader, BufWriter, Write};
 32use std::{path::PathBuf, sync::Arc};
 33
 34use crate::distill::run_distill;
 35use crate::example::{Example, group_examples_by_repo, read_example_files};
 36use crate::format_prompt::run_format_prompt;
 37use crate::load_project::run_load_project;
 38use crate::paths::FAILED_EXAMPLES_DIR;
 39use crate::predict::run_prediction;
 40use crate::progress::Progress;
 41use crate::retrieve_context::run_context_retrieval;
 42use crate::score::run_scoring;
 43use crate::split_commit::SplitCommitArgs;
 44use crate::split_dataset::SplitArgs;
 45use crate::synthesize::{SynthesizeConfig, run_synthesize};
 46
 47#[derive(Parser, Debug)]
 48#[command(name = "ep")]
 49struct EpArgs {
 50    #[arg(long, default_value_t = false)]
 51    printenv: bool,
 52    #[clap(long, default_value_t = 10, global = true)]
 53    max_parallelism: usize,
 54    #[clap(long, global = true)]
 55    limit: Option<usize>,
 56    /// Filter examples by name
 57    #[clap(long, global = true)]
 58    name: Option<String>,
 59    /// Filter examples by repository
 60    #[clap(long, global = true)]
 61    repo: Option<String>,
 62    #[command(subcommand)]
 63    command: Option<Command>,
 64    #[clap(global = true, help = INPUTS_HELP)]
 65    inputs: Vec<PathBuf>,
 66    #[arg(long, short, global = true)]
 67    output: Option<PathBuf>,
 68    #[arg(long, short, global = true)]
 69    in_place: bool,
 70    #[arg(long, short, global = true)]
 71    failfast: bool,
 72}
 73
 74const INPUTS_HELP: &str = r#"
 75Inputs can be file paths or special specifiers:
 76
 77  path
 78      Path to an example(s) file (.md, .json, or .jsonl)
 79
 80  captured-after:{timestamp}
 81      Fetch captured examples from Snowflake after the given RFC3339 timestamp.
 82
 83      You can specify this multiple times and mix it with file inputs.
 84
 85      Required environment variables to connect to Snowflake:
 86          EP_SNOWFLAKE_API_KEY
 87          EP_SNOWFLAKE_BASE_URL
 88
 89      Optional:
 90          EP_SNOWFLAKE_ROLE
 91
 92Examples:
 93
 94  # Predict from a file
 95  ep predict examples.jsonl
 96
 97  # Predict from captured examples after a timestamp
 98  ep predict captured-after:2025-01-01T00:00:00Z
 99
100  # Mix file inputs and captured-after in the same invocation
101  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
102"#;
103
104#[derive(Subcommand, Debug, Clone)]
105enum Command {
106    /// Parse markdown examples and output a combined .jsonl file
107    ParseExample,
108    /// Create git worktrees for each example and load file contents
109    LoadProject,
110    /// Retrieve context for input examples.
111    Context,
112    /// Generate a prompt string for a specific model
113    FormatPrompt(FormatPromptArgs),
114    /// Runs edit prediction
115    Predict(PredictArgs),
116    /// Computes a score based on actual and expected patches
117    Score(PredictArgs),
118    /// Prepares a distillation dataset by copying expected outputs to
119    /// predicted outputs and removing actual outputs and prompts.
120    Distill,
121    /// Print aggregated scores
122    Eval(PredictArgs),
123    /// Generate eval examples by analyzing git commits from a repository
124    Synthesize(SynthesizeArgs),
125    /// Remove git repositories and worktrees
126    Clean,
127    /// Generate an evaluation example by splitting a chronologically-ordered commit
128    SplitCommit(SplitCommitArgs),
129    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
130    Split(SplitArgs),
131}
132
133impl Display for Command {
134    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135        match self {
136            Command::ParseExample => write!(f, "parse-example"),
137            Command::LoadProject => write!(f, "load-project"),
138            Command::Context => write!(f, "context"),
139            Command::FormatPrompt(format_prompt_args) => write!(
140                f,
141                "format-prompt --prompt-format={}",
142                format_prompt_args
143                    .prompt_format
144                    .to_possible_value()
145                    .unwrap()
146                    .get_name()
147            ),
148            Command::Predict(predict_args) => {
149                write!(
150                    f,
151                    "predict --provider={:?}",
152                    predict_args
153                        .provider
154                        .to_possible_value()
155                        .unwrap()
156                        .get_name()
157                )
158            }
159            Command::Score(predict_args) => {
160                write!(
161                    f,
162                    "score --provider={:?}",
163                    predict_args
164                        .provider
165                        .to_possible_value()
166                        .unwrap()
167                        .get_name()
168                )
169            }
170            Command::Distill => write!(f, "distill"),
171            Command::Eval(predict_args) => write!(
172                f,
173                "eval --provider={:?}",
174                predict_args
175                    .provider
176                    .to_possible_value()
177                    .unwrap()
178                    .get_name()
179            ),
180            Command::Synthesize(args) => {
181                write!(f, "synthesize --repo={}", args.repo)
182            }
183            Command::Clean => write!(f, "clean"),
184            Command::SplitCommit(_) => write!(f, "split-commit"),
185            Command::Split(_) => write!(f, "split"),
186        }
187    }
188}
189
190#[derive(Debug, Args, Clone)]
191struct FormatPromptArgs {
192    #[clap(long)]
193    prompt_format: PromptFormat,
194}
195
196#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
197enum PromptFormat {
198    Teacher,
199    Zeta2,
200}
201
202#[derive(Debug, Args, Clone)]
203struct PredictArgs {
204    #[clap(long)]
205    provider: PredictionProvider,
206    #[clap(long, default_value_t = 1)]
207    repetitions: usize,
208}
209
210#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
211enum PredictionProvider {
212    Sweep,
213    Mercury,
214    Zeta1,
215    Zeta2,
216    Teacher,
217    TeacherNonBatching,
218}
219
220#[derive(Debug, Args, Clone)]
221struct SynthesizeArgs {
222    /// Repository URL (git@github.com:owner/repo or https://...)
223    #[clap(long)]
224    repo: String,
225
226    /// Number of examples to generate
227    #[clap(long, default_value_t = 5)]
228    count: usize,
229
230    /// Maximum commits to scan before giving up
231    #[clap(long, default_value_t = 100)]
232    max_commits: usize,
233
234    /// Ignore state file and reprocess all commits
235    #[clap(long)]
236    fresh: bool,
237}
238
239impl EpArgs {
240    fn output_path(&self) -> Option<PathBuf> {
241        if self.in_place {
242            if self.inputs.len() == 1 {
243                self.inputs.first().cloned()
244            } else {
245                panic!("--in-place requires exactly one input file")
246            }
247        } else {
248            self.output.clone()
249        }
250    }
251}
252
253async fn load_examples(
254    http_client: Arc<dyn http_client::HttpClient>,
255    args: &EpArgs,
256    output_path: Option<&PathBuf>,
257) -> anyhow::Result<Vec<Example>> {
258    let mut captured_after_timestamps = Vec::new();
259    let mut file_inputs = Vec::new();
260
261    for input in &args.inputs {
262        let input_string = input.to_string_lossy();
263        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
264            captured_after_timestamps.push(timestamp.to_string());
265        } else {
266            file_inputs.push(input.clone());
267        }
268    }
269
270    let mut examples = read_example_files(&file_inputs);
271
272    Progress::global().set_total_examples(examples.len());
273
274    let remaining_limit_for_snowflake =
275        args.limit.map(|limit| limit.saturating_sub(examples.len()));
276
277    if let Some(0) = remaining_limit_for_snowflake {
278        log::info!(
279            "skipping captured-after inputs because --limit is already satisfied by example files"
280        );
281    } else if !captured_after_timestamps.is_empty() {
282        captured_after_timestamps.sort();
283
284        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
285
286        let mut captured_examples = pull_examples::fetch_captured_examples_after(
287            http_client,
288            &captured_after_timestamps,
289            max_rows_per_timestamp,
290        )
291        .await?;
292        examples.append(&mut captured_examples);
293    }
294
295    crate::example::sort_examples_by_repo_and_rev(&mut examples);
296
297    if let Some(name_filter) = &args.name {
298        examples.retain(|example| example.spec.name.contains(name_filter));
299    }
300    if let Some(repo_filter) = &args.repo {
301        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
302    }
303
304    if let Some(limit) = args.limit {
305        if examples.len() > limit {
306            examples.truncate(limit);
307        }
308    }
309
310    if let Some(path) = output_path {
311        resume_from_output(path, &mut examples);
312    }
313
314    Progress::global().set_total_examples(examples.len());
315
316    Ok(examples)
317}
318
319fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
320    let mut hasher = collections::FxHasher::default();
321    spec.hash(&mut hasher);
322    hasher.finish()
323}
324
325fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
326    let file = match File::open(path) {
327        Ok(f) => f,
328        Err(_) => return,
329    };
330
331    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
332
333    let reader = BufReader::new(file);
334    let mut kept_lines = Vec::new();
335    let mut kept_hashes = HashSet::default();
336
337    for line in reader.lines() {
338        let line = match line {
339            Ok(l) => l,
340            Err(_) => continue,
341        };
342
343        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
344            let hash = spec_hash(&output_example.spec);
345            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
346                kept_hashes.insert(hash);
347                kept_lines.push(line);
348            }
349        }
350    }
351
352    let total = examples.len();
353    let already_processed = kept_hashes.len();
354
355    eprintln!(
356        "Resuming: {}/{} examples already processed",
357        already_processed, total
358    );
359
360    let file = OpenOptions::new()
361        .write(true)
362        .truncate(true)
363        .open(path)
364        .expect("Failed to open output file for rewriting");
365    let mut writer = BufWriter::new(file);
366    for line in &kept_lines {
367        writeln!(writer, "{}", line).expect("Failed to write to output file");
368    }
369    writer.flush().expect("Failed to flush output file");
370
371    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
372}
373
374fn main() {
375    let args = EpArgs::parse();
376
377    if args.printenv {
378        ::util::shell_env::print_env();
379        return;
380    }
381
382    let output = args.output_path();
383    let command = match &args.command {
384        Some(cmd) => cmd.clone(),
385        None => {
386            EpArgs::command().print_help().unwrap();
387            return;
388        }
389    };
390
391    match &command {
392        Command::Clean => {
393            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
394            return;
395        }
396        Command::Synthesize(synth_args) => {
397            let Some(output_dir) = args.output else {
398                panic!("output dir is required");
399            };
400            let config = SynthesizeConfig {
401                repo_url: synth_args.repo.clone(),
402                count: synth_args.count,
403                max_commits: synth_args.max_commits,
404                output_dir,
405                fresh: synth_args.fresh,
406            };
407            smol::block_on(async {
408                if let Err(e) = run_synthesize(config).await {
409                    eprintln!("Error: {:?}", e);
410                    std::process::exit(1);
411                }
412            });
413            return;
414        }
415        Command::SplitCommit(split_commit_args) => {
416            if let Err(error) =
417                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
418            {
419                eprintln!("{error:#}");
420                std::process::exit(1);
421            }
422            return;
423        }
424        Command::Split(split_args) => {
425            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
426                eprintln!("{error:#}");
427                std::process::exit(1);
428            }
429            return;
430        }
431        _ => {}
432    }
433
434    let http_client = Arc::new(ReqwestClient::new());
435    let app = Application::headless().with_http_client(http_client);
436
437    app.run(move |cx| {
438        let app_state = Arc::new(headless::init(cx));
439        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
440
441        cx.spawn(async move |cx| {
442            let result = async {
443                let mut examples =
444                    load_examples(app_state.client.http_client(), &args, output.as_ref()).await?;
445
446                match &command {
447                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
448                        predict::sync_batches(&args.provider).await?;
449                    }
450                    _ => (),
451                }
452
453                let failfast_on_single_example = examples.len() == 1;
454
455                let output_sender: Option<mpsc::UnboundedSender<String>> =
456                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
457                        output.as_ref().map(|path| {
458                            let file = OpenOptions::new()
459                                .create(true)
460                                .append(true)
461                                .open(path)
462                                .expect("Failed to open output file");
463                            let mut writer = BufWriter::new(file);
464                            let (sender, mut receiver) = mpsc::unbounded::<String>();
465                            cx.background_spawn(async move {
466                                while let Some(line) = receiver.next().await {
467                                    writeln!(writer, "{}", line).expect("Failed to write example");
468                                    writer.flush().expect("Failed to flush output");
469                                }
470                            })
471                            .detach();
472                            sender
473                        })
474                    } else {
475                        None
476                    };
477
478                let mut grouped_examples = group_examples_by_repo(&mut examples);
479                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
480
481                for example_batch in example_batches {
482                    let futures = example_batch.into_iter().map(|repo_examples| async {
483                        for example in repo_examples.iter_mut() {
484                            let result = async {
485                                match &command {
486                                    Command::ParseExample => {}
487                                    Command::LoadProject => {
488                                        run_load_project(example, app_state.clone(), cx.clone())
489                                            .await?;
490                                    }
491                                    Command::Context => {
492                                        run_context_retrieval(
493                                            example,
494                                            app_state.clone(),
495                                            cx.clone(),
496                                        )
497                                        .await?;
498                                    }
499                                    Command::FormatPrompt(args) => {
500                                        run_format_prompt(
501                                            example,
502                                            args.prompt_format,
503                                            app_state.clone(),
504                                            cx.clone(),
505                                        )
506                                        .await?;
507                                    }
508                                    Command::Predict(args) => {
509                                        run_prediction(
510                                            example,
511                                            Some(args.provider),
512                                            args.repetitions,
513                                            app_state.clone(),
514                                            cx.clone(),
515                                        )
516                                        .await?;
517                                    }
518                                    Command::Distill => {
519                                        run_distill(example).await?;
520                                    }
521                                    Command::Score(args) | Command::Eval(args) => {
522                                        run_scoring(example, &args, app_state.clone(), cx.clone())
523                                            .await?;
524                                    }
525                                    Command::Clean
526                                    | Command::Synthesize(_)
527                                    | Command::SplitCommit(_)
528                                    | Command::Split(_) => {
529                                        unreachable!()
530                                    }
531                                }
532                                anyhow::Ok(())
533                            }
534                            .await;
535
536                            if let Err(error) = result {
537                                handle_error(
538                                    error,
539                                    &args,
540                                    &command,
541                                    &app_state,
542                                    failfast_on_single_example,
543                                    example,
544                                )
545                                .await;
546                            }
547
548                            if let Some(ref mut sender) = output_sender.clone() {
549                                let line = serde_json::to_string(example).unwrap();
550                                sender
551                                    .send(line)
552                                    .await
553                                    .expect("Failed to send to output writer");
554                            } else if args.output.is_none() && !matches!(command, Command::Eval(_))
555                            {
556                                let line = serde_json::to_string(example).unwrap();
557                                println!("{}", line);
558                            }
559                        }
560                    });
561                    futures::future::join_all(futures).await;
562                }
563
564                Progress::global().finalize();
565
566                match &command {
567                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
568                        predict::sync_batches(&args.provider).await?;
569                    }
570                    _ => (),
571                }
572
573                match &command {
574                    Command::Eval(_) => score::print_report(&examples),
575                    _ => (),
576                };
577
578                anyhow::Ok(())
579            }
580            .await;
581
582            if let Err(e) = result {
583                panic!("Fatal error: {:?}", e);
584            }
585
586            let _ = cx.update(|cx| cx.quit());
587        })
588        .detach();
589    });
590}
591
592async fn handle_error(
593    error: anyhow::Error,
594    args: &EpArgs,
595    command: &Command,
596    app_state: &Arc<headless::EpAppState>,
597    failfast_on_single_example: bool,
598    example: &Example,
599) {
600    Progress::global().increment_failed();
601    let example_name = example.spec.filename();
602    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
603    app_state
604        .fs
605        .write(
606            &failed_example_path,
607            &serde_json::to_vec_pretty(&example).unwrap(),
608        )
609        .await
610        .unwrap();
611    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
612    app_state
613        .fs
614        .write(&err_path, format!("{error:?}").as_bytes())
615        .await
616        .unwrap();
617
618    let cursor_path = example
619        .repo_name()
620        .unwrap()
621        .worktree_path()
622        .join(&example.spec.cursor_path);
623
624    let msg = format!(
625        indoc::indoc! {"
626            While processing \"{}\":
627
628            \x1b[31m{:?}\x1b[0m
629
630            Example:        \x1b[36m{}\x1b[0m
631            Error file:     \x1b[36m{}\x1b[0m
632            Cursor file:    \x1b[36m{}\x1b[0m
633            Re-run:         cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
634        "},
635        example.spec.name,
636        error,
637        failed_example_path.display(),
638        err_path.display(),
639        cursor_path.display(),
640        command,
641        failed_example_path.display(),
642    );
643    if args.failfast || failfast_on_single_example {
644        Progress::global().finalize();
645        panic!("{}", msg);
646    } else {
647        log::error!("{}", msg);
648    }
649}