main.rs

  1mod anthropic_client;
  2mod distill;
  3mod example;
  4mod format_prompt;
  5mod git;
  6mod headless;
  7mod load_project;
  8mod metrics;
  9mod paths;
 10mod predict;
 11mod progress;
 12mod pull_examples;
 13mod reorder_patch;
 14mod retrieve_context;
 15mod score;
 16mod split_commit;
 17mod split_dataset;
 18mod synthesize;
 19use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
 20use collections::HashSet;
 21use edit_prediction::EditPredictionStore;
 22use futures::channel::mpsc;
 23use futures::{SinkExt as _, StreamExt as _};
 24use gpui::{AppContext as _, Application};
 25
 26use reqwest_client::ReqwestClient;
 27use serde::{Deserialize, Serialize};
 28use std::fmt::Display;
 29use std::fs::{File, OpenOptions};
 30use std::hash::{Hash, Hasher};
 31use std::io::{BufRead, BufReader, BufWriter, Write};
 32use std::{path::PathBuf, sync::Arc};
 33
 34use crate::distill::run_distill;
 35use crate::example::{Example, group_examples_by_repo, read_example_files};
 36use crate::format_prompt::run_format_prompt;
 37use crate::load_project::run_load_project;
 38use crate::paths::FAILED_EXAMPLES_DIR;
 39use crate::predict::run_prediction;
 40use crate::progress::Progress;
 41use crate::retrieve_context::run_context_retrieval;
 42use crate::score::run_scoring;
 43use crate::split_commit::SplitCommitArgs;
 44use crate::split_dataset::SplitArgs;
 45use crate::synthesize::{SynthesizeConfig, run_synthesize};
 46
// Top-level command-line arguments shared by every `ep` subcommand.
//
// Plain `//` comments are used here on purpose: clap's derive macros turn `///`
// doc comments into help text, so adding those would change the CLI output.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // When set, `main` prints the shell environment and exits before doing anything else.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Chunk size `main` uses when batching example groups that are awaited together.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Upper bound on the number of examples kept by `load_examples`.
    #[clap(long, global = true)]
    limit: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    // Subcommand to run; `main` prints the help text when it is absent.
    #[command(subcommand)]
    command: Option<Command>,
    // Example files and/or `captured-after:{timestamp}` specifiers (see `INPUTS_HELP`).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Destination for results; `synthesize` requires it and uses it as its output dir.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Use the single input path as the output path (see `EpArgs::output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Panic on the first failing example instead of logging it and continuing
    // (see `handle_error`).
    #[arg(long, short, global = true)]
    failfast: bool,
}
 73
// Help text attached to the positional `inputs` argument of `EpArgs`.
// The string is shown to users verbatim, so its layout (including the leading
// and trailing newline of the raw string) is part of the CLI output.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.

      You can specify this multiple times and mix it with file inputs.

      Required environment variables to connect to Snowflake:
          EP_SNOWFLAKE_API_KEY
          EP_SNOWFLAKE_BASE_URL

      Optional:
          EP_SNOWFLAKE_ROLE

Examples:

  # Predict from a file
  ep predict examples.jsonl

  # Predict from captured examples after a timestamp
  ep predict captured-after:2025-01-01T00:00:00Z

  # Mix file inputs and captured-after in the same invocation
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
103
// Subcommands of `ep`.
//
// `Clean`, `Synthesize`, `SplitCommit` and `Split` are handled directly in
// `main` before the headless application starts; every other variant is run
// once per loaded example inside the application loop.
//
// The `///` comments on the variants are picked up by clap as help text, so
// they are part of the CLI output rather than ordinary documentation.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
}
132
133impl Display for Command {
134    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135        match self {
136            Command::ParseExample => write!(f, "parse-example"),
137            Command::LoadProject => write!(f, "load-project"),
138            Command::Context => write!(f, "context"),
139            Command::FormatPrompt(format_prompt_args) => write!(
140                f,
141                "format-prompt --prompt-format={}",
142                format_prompt_args
143                    .prompt_format
144                    .to_possible_value()
145                    .unwrap()
146                    .get_name()
147            ),
148            Command::Predict(predict_args) => {
149                write!(
150                    f,
151                    "predict --provider={:?}",
152                    predict_args
153                        .provider
154                        .to_possible_value()
155                        .unwrap()
156                        .get_name()
157                )
158            }
159            Command::Score(predict_args) => {
160                write!(
161                    f,
162                    "score --provider={:?}",
163                    predict_args
164                        .provider
165                        .to_possible_value()
166                        .unwrap()
167                        .get_name()
168                )
169            }
170            Command::Distill => write!(f, "distill"),
171            Command::Eval(predict_args) => write!(
172                f,
173                "eval --provider={:?}",
174                predict_args
175                    .provider
176                    .to_possible_value()
177                    .unwrap()
178                    .get_name()
179            ),
180            Command::Synthesize(args) => {
181                write!(f, "synthesize --repo={}", args.repo)
182            }
183            Command::Clean => write!(f, "clean"),
184            Command::SplitCommit(_) => write!(f, "split-commit"),
185            Command::Split(_) => write!(f, "split"),
186        }
187    }
188}
189
// Arguments of the `format-prompt` subcommand.
// (`//` rather than `///`: clap would otherwise surface the text as help.)
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Prompt layout handed to `run_format_prompt`; required (no default).
    #[clap(long)]
    prompt_format: PromptFormat,
}
195
// Values accepted by `format-prompt --prompt-format`.
// (`//` rather than `///`: clap would otherwise surface the text as help.)
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}
201
// Arguments shared by the `predict`, `score` and `eval` subcommands.
// (`//` rather than `///`: clap would otherwise surface the text as help.)
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Prediction backend to use; required (no default).
    #[clap(long)]
    provider: PredictionProvider,
    // Forwarded to `run_prediction` by the `predict` subcommand; defaults to 1.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
209
// Values accepted by `--provider` on the `predict`, `score` and `eval` subcommands.
// (`//` rather than `///`: clap would otherwise surface the text as help.)
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}
219
// Arguments of the `synthesize` subcommand; `main` copies them into a
// `SynthesizeConfig` together with the required `--output` directory.
// (`//` rather than `///` here: the `///` comments on the fields below are
// clap help text, and a struct-level one could be surfaced the same way.)
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URL (git@github.com:owner/repo or https://...)
    #[clap(long)]
    repo: String,

    /// Number of examples to generate
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
238
239impl EpArgs {
240    fn output_path(&self) -> Option<PathBuf> {
241        if self.in_place {
242            if self.inputs.len() == 1 {
243                self.inputs.first().cloned()
244            } else {
245                panic!("--in-place requires exactly one input file")
246            }
247        } else {
248            self.output.clone()
249        }
250    }
251}
252
253async fn load_examples(
254    http_client: Arc<dyn http_client::HttpClient>,
255    args: &EpArgs,
256    output_path: Option<&PathBuf>,
257) -> anyhow::Result<Vec<Example>> {
258    let mut captured_after_timestamps = Vec::new();
259    let mut file_inputs = Vec::new();
260
261    for input in &args.inputs {
262        let input_string = input.to_string_lossy();
263        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
264            captured_after_timestamps.push(timestamp.to_string());
265        } else {
266            file_inputs.push(input.clone());
267        }
268    }
269
270    let mut examples = read_example_files(&file_inputs);
271
272    Progress::global().set_total_examples(examples.len());
273
274    let remaining_limit_for_snowflake =
275        args.limit.map(|limit| limit.saturating_sub(examples.len()));
276
277    if let Some(0) = remaining_limit_for_snowflake {
278        log::info!(
279            "skipping captured-after inputs because --limit is already satisfied by example files"
280        );
281    } else if !captured_after_timestamps.is_empty() {
282        captured_after_timestamps.sort();
283
284        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
285
286        let mut captured_examples = pull_examples::fetch_captured_examples_after(
287            http_client,
288            &captured_after_timestamps,
289            max_rows_per_timestamp,
290        )
291        .await?;
292        examples.append(&mut captured_examples);
293    }
294
295    crate::example::sort_examples_by_repo_and_rev(&mut examples);
296
297    if let Some(name_filter) = &args.name {
298        examples.retain(|example| example.spec.name.contains(name_filter));
299    }
300    if let Some(repo_filter) = &args.repo {
301        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
302    }
303
304    if let Some(limit) = args.limit {
305        if examples.len() > limit {
306            examples.truncate(limit);
307        }
308    }
309
310    if let Some(path) = output_path {
311        resume_from_output(path, &mut examples);
312    }
313
314    Progress::global().set_total_examples(examples.len());
315
316    Ok(examples)
317}
318
319fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
320    let mut hasher = collections::FxHasher::default();
321    spec.hash(&mut hasher);
322    hasher.finish()
323}
324
325fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
326    let file = match File::open(path) {
327        Ok(f) => f,
328        Err(_) => return,
329    };
330
331    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
332
333    let reader = BufReader::new(file);
334    let mut kept_lines = Vec::new();
335    let mut kept_hashes = HashSet::default();
336
337    for line in reader.lines() {
338        let line = match line {
339            Ok(l) => l,
340            Err(_) => continue,
341        };
342
343        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
344            let hash = spec_hash(&output_example.spec);
345            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
346                kept_hashes.insert(hash);
347                kept_lines.push(line);
348            }
349        }
350    }
351
352    let total = examples.len();
353    let already_processed = kept_hashes.len();
354
355    eprintln!(
356        "Resuming: {}/{} examples already processed",
357        already_processed, total
358    );
359
360    let file = OpenOptions::new()
361        .write(true)
362        .truncate(true)
363        .open(path)
364        .expect("Failed to open output file for rewriting");
365    let mut writer = BufWriter::new(file);
366    for line in &kept_lines {
367        writeln!(writer, "{}", line).expect("Failed to write to output file");
368    }
369    writer.flush().expect("Failed to flush output file");
370
371    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
372}
373
374fn main() {
375    let args = EpArgs::parse();
376
377    if args.printenv {
378        ::util::shell_env::print_env();
379        return;
380    }
381
382    let output = args.output_path();
383    let command = match &args.command {
384        Some(cmd) => cmd.clone(),
385        None => {
386            EpArgs::command().print_help().unwrap();
387            return;
388        }
389    };
390
391    match &command {
392        Command::Clean => {
393            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
394            return;
395        }
396        Command::Synthesize(synth_args) => {
397            let Some(output_dir) = args.output else {
398                panic!("output dir is required");
399            };
400            let config = SynthesizeConfig {
401                repo_url: synth_args.repo.clone(),
402                count: synth_args.count,
403                max_commits: synth_args.max_commits,
404                output_dir,
405                fresh: synth_args.fresh,
406            };
407            smol::block_on(async {
408                if let Err(e) = run_synthesize(config).await {
409                    eprintln!("Error: {:?}", e);
410                    std::process::exit(1);
411                }
412            });
413            return;
414        }
415        Command::SplitCommit(split_commit_args) => {
416            if let Err(error) =
417                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
418            {
419                eprintln!("{error:#}");
420                std::process::exit(1);
421            }
422            return;
423        }
424        Command::Split(split_args) => {
425            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
426                eprintln!("{error:#}");
427                std::process::exit(1);
428            }
429            return;
430        }
431        _ => {}
432    }
433
434    let http_client = Arc::new(ReqwestClient::new());
435    let app = Application::headless().with_http_client(http_client);
436
437    app.run(move |cx| {
438        let app_state = Arc::new(headless::init(cx));
439        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
440
441        cx.spawn(async move |cx| {
442            let result = async {
443                let mut examples =
444                    load_examples(app_state.client.http_client(), &args, output.as_ref()).await?;
445
446                if let Command::Predict(args) = &command {
447                    predict::sync_batches(&args.provider).await?;
448                }
449
450                let failfast_on_single_example = examples.len() == 1;
451
452                let output_sender: Option<mpsc::UnboundedSender<String>> =
453                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
454                        output.as_ref().map(|path| {
455                            let file = OpenOptions::new()
456                                .create(true)
457                                .append(true)
458                                .open(path)
459                                .expect("Failed to open output file");
460                            let mut writer = BufWriter::new(file);
461                            let (sender, mut receiver) = mpsc::unbounded::<String>();
462                            cx.background_spawn(async move {
463                                while let Some(line) = receiver.next().await {
464                                    writeln!(writer, "{}", line).expect("Failed to write example");
465                                    writer.flush().expect("Failed to flush output");
466                                }
467                            })
468                            .detach();
469                            sender
470                        })
471                    } else {
472                        None
473                    };
474
475                let mut grouped_examples = group_examples_by_repo(&mut examples);
476                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
477
478                for example_batch in example_batches {
479                    let futures = example_batch.into_iter().map(|repo_examples| async {
480                        for example in repo_examples.iter_mut() {
481                            let result = async {
482                                match &command {
483                                    Command::ParseExample => {}
484                                    Command::LoadProject => {
485                                        run_load_project(example, app_state.clone(), cx.clone())
486                                            .await?;
487                                    }
488                                    Command::Context => {
489                                        run_context_retrieval(
490                                            example,
491                                            app_state.clone(),
492                                            cx.clone(),
493                                        )
494                                        .await?;
495                                    }
496                                    Command::FormatPrompt(args) => {
497                                        run_format_prompt(
498                                            example,
499                                            args.prompt_format,
500                                            app_state.clone(),
501                                            cx.clone(),
502                                        )
503                                        .await?;
504                                    }
505                                    Command::Predict(args) => {
506                                        run_prediction(
507                                            example,
508                                            Some(args.provider),
509                                            args.repetitions,
510                                            app_state.clone(),
511                                            cx.clone(),
512                                        )
513                                        .await?;
514                                    }
515                                    Command::Distill => {
516                                        run_distill(example).await?;
517                                    }
518                                    Command::Score(args) | Command::Eval(args) => {
519                                        run_scoring(example, &args, app_state.clone(), cx.clone())
520                                            .await?;
521                                    }
522                                    Command::Clean
523                                    | Command::Synthesize(_)
524                                    | Command::SplitCommit(_)
525                                    | Command::Split(_) => {
526                                        unreachable!()
527                                    }
528                                }
529                                anyhow::Ok(())
530                            }
531                            .await;
532
533                            if let Err(error) = result {
534                                handle_error(
535                                    error,
536                                    &args,
537                                    &command,
538                                    &app_state,
539                                    failfast_on_single_example,
540                                    example,
541                                )
542                                .await;
543                            }
544
545                            if let Some(ref mut sender) = output_sender.clone() {
546                                let line = serde_json::to_string(example).unwrap();
547                                sender
548                                    .send(line)
549                                    .await
550                                    .expect("Failed to send to output writer");
551                            } else if args.output.is_none() && !matches!(command, Command::Eval(_))
552                            {
553                                let line = serde_json::to_string(example).unwrap();
554                                println!("{}", line);
555                            }
556                        }
557                    });
558                    futures::future::join_all(futures).await;
559                }
560
561                Progress::global().finalize();
562
563                match &command {
564                    Command::Predict(args) => predict::sync_batches(&args.provider).await?,
565                    Command::Eval(_) => score::print_report(&examples),
566                    _ => (),
567                };
568
569                anyhow::Ok(())
570            }
571            .await;
572
573            if let Err(e) = result {
574                panic!("Fatal error: {:?}", e);
575            }
576
577            let _ = cx.update(|cx| cx.quit());
578        })
579        .detach();
580    });
581}
582
583async fn handle_error(
584    error: anyhow::Error,
585    args: &EpArgs,
586    command: &Command,
587    app_state: &Arc<headless::EpAppState>,
588    failfast_on_single_example: bool,
589    example: &Example,
590) {
591    Progress::global().increment_failed();
592    let example_name = example.spec.filename();
593    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
594    app_state
595        .fs
596        .write(
597            &failed_example_path,
598            &serde_json::to_vec_pretty(&example).unwrap(),
599        )
600        .await
601        .unwrap();
602    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
603    app_state
604        .fs
605        .write(&err_path, format!("{error:?}").as_bytes())
606        .await
607        .unwrap();
608
609    let file_path = example
610        .repo_name()
611        .unwrap()
612        .worktree_path()
613        .join(&example.spec.cursor_path);
614
615    let msg = format!(
616        indoc::indoc! {"
617            While processing \"{}\":
618
619            \x1b[31m{:?}\x1b[0m
620
621            Example:        \x1b[36m{}\x1b[0m
622            Error file:     \x1b[36m{}\x1b[0m
623            Cursor file:    \x1b[36m{}\x1b[0m
624            Re-run:         cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
625        "},
626        example.spec.name,
627        error,
628        err_path.display(),
629        file_path.display(),
630        failed_example_path.display(),
631        command,
632        failed_example_path.display(),
633    );
634    if args.failfast || failfast_on_single_example {
635        Progress::global().finalize();
636        panic!("{}", msg);
637    } else {
638        log::error!("{}", msg);
639    }
640}