main.rs

  1mod anthropic_client;
  2mod distill;
  3mod example;
  4mod format_prompt;
  5mod git;
  6mod headless;
  7mod load_project;
  8mod metrics;
  9mod paths;
 10mod predict;
 11mod progress;
 12mod pull_examples;
 13mod reorder_patch;
 14mod retrieve_context;
 15mod score;
 16mod split_commit;
 17mod synthesize;
 18use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
 19use collections::HashSet;
 20use edit_prediction::EditPredictionStore;
 21use futures::channel::mpsc;
 22use futures::{SinkExt as _, StreamExt as _};
 23use gpui::{AppContext as _, Application};
 24
 25use reqwest_client::ReqwestClient;
 26use serde::{Deserialize, Serialize};
 27use std::fmt::Display;
 28use std::fs::{File, OpenOptions};
 29use std::hash::{Hash, Hasher};
 30use std::io::{BufRead, BufReader, BufWriter, Write};
 31use std::{path::PathBuf, sync::Arc};
 32
 33use crate::distill::run_distill;
 34use crate::example::{Example, group_examples_by_repo, read_example_files};
 35use crate::format_prompt::run_format_prompt;
 36use crate::load_project::run_load_project;
 37use crate::paths::FAILED_EXAMPLES_DIR;
 38use crate::predict::run_prediction;
 39use crate::progress::Progress;
 40use crate::retrieve_context::run_context_retrieval;
 41use crate::score::run_scoring;
 42use crate::split_commit::SplitCommitArgs;
 43use crate::synthesize::{SynthesizeConfig, run_synthesize};
 44
// Top-level command-line arguments for the `ep` binary.
//
// Maintainer notes in this struct are plain `//` comments on purpose: `///` doc
// comments on clap-derived items are picked up by clap as user-visible help text.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // When set, `main` calls `util::shell_env::print_env()` and returns immediately.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Chunk size `main` uses when batching per-repository example groups; the
    // groups of one batch are awaited together before the next batch starts.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Upper bound on the number of examples kept by `load_examples`.
    #[clap(long, global = true)]
    limit: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    // Subcommand to run; `main` prints the help text when it is absent.
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: example file paths or `captured-after:` specifiers
    // (see `INPUTS_HELP`).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output path. Ignored by `output_path()` when `in_place` is set; used as the
    // output directory by the `synthesize` subcommand.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Use the single input file as the output path (see `output_path()`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Makes `handle_error` panic on the first failed example instead of logging it.
    #[arg(long, short, global = true)]
    failfast: bool,
}
 71
// Help text attached to the positional `inputs` argument of `EpArgs`.
// This is user-facing runtime text; edit with care.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.

      You can specify this multiple times and mix it with file inputs.

      Required environment variables to connect to Snowflake:
          EP_SNOWFLAKE_API_KEY
          EP_SNOWFLAKE_BASE_URL

      Optional:
          EP_SNOWFLAKE_ROLE

Examples:

  # Predict from a file
  ep predict examples.jsonl

  # Predict from captured examples after a timestamp
  ep predict captured-after:2025-01-01T00:00:00Z

  # Mix file inputs and captured-after in the same invocation
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
101
// Subcommands of the `ep` CLI.
//
// `Clean`, `Synthesize` and `SplitCommit` are handled directly in `main` and
// return early; all other variants are applied to each loaded example inside
// the headless application. The `///` comments on the variants are clap help
// text, so they are user-visible.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
}
128
129impl Display for Command {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        match self {
132            Command::ParseExample => write!(f, "parse-example"),
133            Command::LoadProject => write!(f, "load-project"),
134            Command::Context => write!(f, "context"),
135            Command::FormatPrompt(format_prompt_args) => write!(
136                f,
137                "format-prompt --prompt-format={}",
138                format_prompt_args
139                    .prompt_format
140                    .to_possible_value()
141                    .unwrap()
142                    .get_name()
143            ),
144            Command::Predict(predict_args) => {
145                write!(
146                    f,
147                    "predict --provider={:?}",
148                    predict_args
149                        .provider
150                        .to_possible_value()
151                        .unwrap()
152                        .get_name()
153                )
154            }
155            Command::Score(predict_args) => {
156                write!(
157                    f,
158                    "score --provider={:?}",
159                    predict_args
160                        .provider
161                        .to_possible_value()
162                        .unwrap()
163                        .get_name()
164                )
165            }
166            Command::Distill => write!(f, "distill"),
167            Command::Eval(predict_args) => write!(
168                f,
169                "eval --provider={:?}",
170                predict_args
171                    .provider
172                    .to_possible_value()
173                    .unwrap()
174                    .get_name()
175            ),
176            Command::Synthesize(args) => {
177                write!(f, "synthesize --repo={}", args.repo)
178            }
179            Command::Clean => write!(f, "clean"),
180            Command::SplitCommit(_) => write!(f, "split-commit"),
181        }
182    }
183}
184
// Arguments of the `format-prompt` subcommand.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Prompt format to generate; `main` forwards it to `run_format_prompt`.
    #[clap(long)]
    prompt_format: PromptFormat,
}
190
// Values accepted by `--prompt-format` (see `FormatPromptArgs`).
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}
196
// Arguments shared by the `predict`, `score` and `eval` subcommands.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Provider to use; required on the command line.
    #[clap(long)]
    provider: PredictionProvider,
    // Forwarded to `run_prediction` by the `predict` subcommand; defaults to 1.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
204
// Values accepted by `--provider` (see `PredictArgs`).
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}
214
// Arguments of the `synthesize` subcommand. `main` copies these into a
// `SynthesizeConfig`, together with the global `--output` value as the output
// directory. The `///` comments on the fields are clap help text.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URL (git@github.com:owner/repo or https://...)
    #[clap(long)]
    repo: String,

    /// Number of examples to generate
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
233
234impl EpArgs {
235    fn output_path(&self) -> Option<PathBuf> {
236        if self.in_place {
237            if self.inputs.len() == 1 {
238                self.inputs.first().cloned()
239            } else {
240                panic!("--in-place requires exactly one input file")
241            }
242        } else {
243            self.output.clone()
244        }
245    }
246}
247
248async fn load_examples(
249    http_client: Arc<dyn http_client::HttpClient>,
250    args: &EpArgs,
251    output_path: Option<&PathBuf>,
252) -> anyhow::Result<Vec<Example>> {
253    let mut captured_after_timestamps = Vec::new();
254    let mut file_inputs = Vec::new();
255
256    for input in &args.inputs {
257        let input_string = input.to_string_lossy();
258        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
259            captured_after_timestamps.push(timestamp.to_string());
260        } else {
261            file_inputs.push(input.clone());
262        }
263    }
264
265    let mut examples = read_example_files(&file_inputs);
266
267    Progress::global().set_total_examples(examples.len());
268
269    let remaining_limit_for_snowflake =
270        args.limit.map(|limit| limit.saturating_sub(examples.len()));
271
272    if let Some(0) = remaining_limit_for_snowflake {
273        log::info!(
274            "skipping captured-after inputs because --limit is already satisfied by example files"
275        );
276    } else if !captured_after_timestamps.is_empty() {
277        captured_after_timestamps.sort();
278
279        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
280
281        let mut captured_examples = pull_examples::fetch_captured_examples_after(
282            http_client,
283            &captured_after_timestamps,
284            max_rows_per_timestamp,
285        )
286        .await?;
287        examples.append(&mut captured_examples);
288    }
289
290    crate::example::sort_examples_by_repo_and_rev(&mut examples);
291
292    if let Some(name_filter) = &args.name {
293        examples.retain(|example| example.spec.name.contains(name_filter));
294    }
295    if let Some(repo_filter) = &args.repo {
296        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
297    }
298
299    if let Some(limit) = args.limit {
300        if examples.len() > limit {
301            examples.truncate(limit);
302        }
303    }
304
305    if let Some(path) = output_path {
306        resume_from_output(path, &mut examples);
307    }
308
309    Progress::global().set_total_examples(examples.len());
310
311    Ok(examples)
312}
313
314fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
315    let mut hasher = collections::FxHasher::default();
316    spec.hash(&mut hasher);
317    hasher.finish()
318}
319
320fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
321    let file = match File::open(path) {
322        Ok(f) => f,
323        Err(_) => return,
324    };
325
326    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
327
328    let reader = BufReader::new(file);
329    let mut kept_lines = Vec::new();
330    let mut kept_hashes = HashSet::default();
331
332    for line in reader.lines() {
333        let line = match line {
334            Ok(l) => l,
335            Err(_) => continue,
336        };
337
338        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
339            let hash = spec_hash(&output_example.spec);
340            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
341                kept_hashes.insert(hash);
342                kept_lines.push(line);
343            }
344        }
345    }
346
347    let total = examples.len();
348    let already_processed = kept_hashes.len();
349
350    eprintln!(
351        "Resuming: {}/{} examples already processed",
352        already_processed, total
353    );
354
355    let file = OpenOptions::new()
356        .write(true)
357        .truncate(true)
358        .open(path)
359        .expect("Failed to open output file for rewriting");
360    let mut writer = BufWriter::new(file);
361    for line in &kept_lines {
362        writeln!(writer, "{}", line).expect("Failed to write to output file");
363    }
364    writer.flush().expect("Failed to flush output file");
365
366    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
367}
368
369fn main() {
370    let args = EpArgs::parse();
371
372    if args.printenv {
373        ::util::shell_env::print_env();
374        return;
375    }
376
377    let output = args.output_path();
378    let command = match &args.command {
379        Some(cmd) => cmd.clone(),
380        None => {
381            EpArgs::command().print_help().unwrap();
382            return;
383        }
384    };
385
386    match &command {
387        Command::Clean => {
388            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
389            return;
390        }
391        Command::Synthesize(synth_args) => {
392            let Some(output_dir) = args.output else {
393                panic!("output dir is required");
394            };
395            let config = SynthesizeConfig {
396                repo_url: synth_args.repo.clone(),
397                count: synth_args.count,
398                max_commits: synth_args.max_commits,
399                output_dir,
400                fresh: synth_args.fresh,
401            };
402            smol::block_on(async {
403                if let Err(e) = run_synthesize(config).await {
404                    eprintln!("Error: {:?}", e);
405                    std::process::exit(1);
406                }
407            });
408            return;
409        }
410        Command::SplitCommit(split_commit_args) => {
411            if let Err(error) =
412                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
413            {
414                eprintln!("{error:#}");
415                std::process::exit(1);
416            }
417            return;
418        }
419        _ => {}
420    }
421
422    let http_client = Arc::new(ReqwestClient::new());
423    let app = Application::headless().with_http_client(http_client);
424
425    app.run(move |cx| {
426        let app_state = Arc::new(headless::init(cx));
427        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
428
429        cx.spawn(async move |cx| {
430            let result = async {
431                let mut examples =
432                    load_examples(app_state.client.http_client(), &args, output.as_ref()).await?;
433
434                if let Command::Predict(args) = &command {
435                    predict::sync_batches(&args.provider).await?;
436                }
437
438                let failfast_on_single_example = examples.len() == 1;
439
440                let output_sender: Option<mpsc::UnboundedSender<String>> =
441                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
442                        output.as_ref().map(|path| {
443                            let file = OpenOptions::new()
444                                .create(true)
445                                .append(true)
446                                .open(path)
447                                .expect("Failed to open output file");
448                            let mut writer = BufWriter::new(file);
449                            let (sender, mut receiver) = mpsc::unbounded::<String>();
450                            cx.background_spawn(async move {
451                                while let Some(line) = receiver.next().await {
452                                    writeln!(writer, "{}", line).expect("Failed to write example");
453                                    writer.flush().expect("Failed to flush output");
454                                }
455                            })
456                            .detach();
457                            sender
458                        })
459                    } else {
460                        None
461                    };
462
463                let mut grouped_examples = group_examples_by_repo(&mut examples);
464                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
465
466                for example_batch in example_batches {
467                    let futures = example_batch.into_iter().map(|repo_examples| async {
468                        for example in repo_examples.iter_mut() {
469                            let result = async {
470                                match &command {
471                                    Command::ParseExample => {}
472                                    Command::LoadProject => {
473                                        run_load_project(example, app_state.clone(), cx.clone())
474                                            .await?;
475                                    }
476                                    Command::Context => {
477                                        run_context_retrieval(
478                                            example,
479                                            app_state.clone(),
480                                            cx.clone(),
481                                        )
482                                        .await?;
483                                    }
484                                    Command::FormatPrompt(args) => {
485                                        run_format_prompt(
486                                            example,
487                                            args.prompt_format,
488                                            app_state.clone(),
489                                            cx.clone(),
490                                        )
491                                        .await?;
492                                    }
493                                    Command::Predict(args) => {
494                                        run_prediction(
495                                            example,
496                                            Some(args.provider),
497                                            args.repetitions,
498                                            app_state.clone(),
499                                            cx.clone(),
500                                        )
501                                        .await?;
502                                    }
503                                    Command::Distill => {
504                                        run_distill(example).await?;
505                                    }
506                                    Command::Score(args) | Command::Eval(args) => {
507                                        run_scoring(example, &args, app_state.clone(), cx.clone())
508                                            .await?;
509                                    }
510                                    Command::Clean
511                                    | Command::Synthesize(_)
512                                    | Command::SplitCommit(_) => {
513                                        unreachable!()
514                                    }
515                                }
516                                anyhow::Ok(())
517                            }
518                            .await;
519
520                            if let Err(error) = result {
521                                handle_error(
522                                    error,
523                                    &args,
524                                    &command,
525                                    &app_state,
526                                    failfast_on_single_example,
527                                    example,
528                                )
529                                .await;
530                            }
531
532                            if let Some(ref mut sender) = output_sender.clone() {
533                                let line = serde_json::to_string(example).unwrap();
534                                sender
535                                    .send(line)
536                                    .await
537                                    .expect("Failed to send to output writer");
538                            } else if args.output.is_none() && !matches!(command, Command::Eval(_))
539                            {
540                                let line = serde_json::to_string(example).unwrap();
541                                println!("{}", line);
542                            }
543                        }
544                    });
545                    futures::future::join_all(futures).await;
546                }
547
548                Progress::global().finalize();
549
550                match &command {
551                    Command::Predict(args) => predict::sync_batches(&args.provider).await?,
552                    Command::Eval(_) => score::print_report(&examples),
553                    _ => (),
554                };
555
556                anyhow::Ok(())
557            }
558            .await;
559
560            if let Err(e) = result {
561                panic!("Fatal error: {:?}", e);
562            }
563
564            let _ = cx.update(|cx| cx.quit());
565        })
566        .detach();
567    });
568}
569
570async fn handle_error(
571    error: anyhow::Error,
572    args: &EpArgs,
573    command: &Command,
574    app_state: &Arc<headless::EpAppState>,
575    failfast_on_single_example: bool,
576    example: &Example,
577) {
578    Progress::global().increment_failed();
579    let example_name = example.spec.filename();
580    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
581    app_state
582        .fs
583        .write(
584            &failed_example_path,
585            &serde_json::to_vec_pretty(&example).unwrap(),
586        )
587        .await
588        .unwrap();
589    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
590    app_state
591        .fs
592        .write(&err_path, format!("{error:?}").as_bytes())
593        .await
594        .unwrap();
595
596    let file_path = example
597        .repo_name()
598        .unwrap()
599        .worktree_path()
600        .join(&example.spec.cursor_path);
601
602    let msg = format!(
603        indoc::indoc! {"
604            While processing \"{}\":
605
606            {:?}
607
608            Written to: \x1b[36m{}\x1b[0m
609
610            Cursor File: \x1b[36m{}\x1b[0m
611
612            Explore this example data with:
613            fx \x1b[36m{}\x1b[0m
614
615            Re-run this example with:
616            cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
617        "},
618        example.spec.name,
619        error,
620        err_path.display(),
621        file_path.display(),
622        failed_example_path.display(),
623        command,
624        failed_example_path.display(),
625    );
626    if args.failfast || failfast_on_single_example {
627        Progress::global().finalize();
628        panic!("{}", msg);
629    } else {
630        log::error!("{}", msg);
631    }
632}