1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod synthesize;
18use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
19use edit_prediction::EditPredictionStore;
20use gpui::Application;
21use reqwest_client::ReqwestClient;
22use serde::{Deserialize, Serialize};
23use std::fmt::Display;
24use std::{path::PathBuf, sync::Arc};
25
26use crate::distill::run_distill;
27use crate::example::{Example, group_examples_by_repo, read_example_files, write_examples};
28use crate::format_prompt::run_format_prompt;
29use crate::load_project::run_load_project;
30use crate::paths::FAILED_EXAMPLES_DIR;
31use crate::predict::run_prediction;
32use crate::progress::Progress;
33use crate::retrieve_context::run_context_retrieval;
34use crate::score::run_scoring;
35use crate::split_commit::SplitCommitArgs;
36use crate::synthesize::{SynthesizeConfig, run_synthesize};
37
38#[derive(Parser, Debug)]
39#[command(name = "ep")]
40struct EpArgs {
41 #[arg(long, default_value_t = false)]
42 printenv: bool,
43 #[clap(long, default_value_t = 10, global = true)]
44 max_parallelism: usize,
45 #[clap(long, global = true)]
46 limit: Option<usize>,
47 /// Filter examples by name
48 #[clap(long, global = true)]
49 name: Option<String>,
50 /// Filter examples by repository
51 #[clap(long, global = true)]
52 repo: Option<String>,
53 #[command(subcommand)]
54 command: Option<Command>,
55 #[clap(global = true, help = INPUTS_HELP)]
56 inputs: Vec<PathBuf>,
57 #[arg(long, short, global = true)]
58 output: Option<PathBuf>,
59 #[arg(long, short, global = true)]
60 in_place: bool,
61 #[arg(long, short, global = true)]
62 failfast: bool,
63}
64
// Long help text clap attaches to the positional `inputs` argument of
// `EpArgs`. This is a user-facing runtime string — keep its wording and
// formatting exactly as shown when editing.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.

    You can specify this multiple times and mix it with file inputs.

    Required environment variables to connect to Snowflake:
        EP_SNOWFLAKE_API_KEY
        EP_SNOWFLAKE_BASE_URL

    Optional:
        EP_SNOWFLAKE_ROLE

Examples:

    # Predict from a file
    ep predict examples.jsonl

    # Predict from captured examples after a timestamp
    ep predict captured-after:2025-01-01T00:00:00Z

    # Mix file inputs and captured-after in the same invocation
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
94
// Subcommands of the `ep` CLI. The `///` comments below are surfaced by clap
// as user-facing help text, so they must not be reworded casually.
// `Clean`, `Synthesize`, and `SplitCommit` are handled synchronously in
// `main` before the headless app starts; all others run through the
// per-example pipeline.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    // Reuses PredictArgs so scoring knows which provider's output to check.
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    // Also reuses PredictArgs; Eval runs scoring then prints a report.
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
}
121
122impl Display for Command {
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 match self {
125 Command::ParseExample => write!(f, "parse-example"),
126 Command::LoadProject => write!(f, "load-project"),
127 Command::Context => write!(f, "context"),
128 Command::FormatPrompt(format_prompt_args) => write!(
129 f,
130 "format-prompt --prompt-format={}",
131 format_prompt_args
132 .prompt_format
133 .to_possible_value()
134 .unwrap()
135 .get_name()
136 ),
137 Command::Predict(predict_args) => {
138 write!(
139 f,
140 "predict --provider={:?}",
141 predict_args
142 .provider
143 .to_possible_value()
144 .unwrap()
145 .get_name()
146 )
147 }
148 Command::Score(predict_args) => {
149 write!(
150 f,
151 "score --provider={:?}",
152 predict_args
153 .provider
154 .to_possible_value()
155 .unwrap()
156 .get_name()
157 )
158 }
159 Command::Distill => write!(f, "distill"),
160 Command::Eval(predict_args) => write!(
161 f,
162 "eval --provider={:?}",
163 predict_args
164 .provider
165 .to_possible_value()
166 .unwrap()
167 .get_name()
168 ),
169 Command::Synthesize(args) => {
170 write!(f, "synthesize --repo={}", args.repo)
171 }
172 Command::Clean => write!(f, "clean"),
173 Command::SplitCommit(_) => write!(f, "split-commit"),
174 }
175 }
176}
177
// Arguments for the `format-prompt` subcommand.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Which prompt flavor to emit; see `PromptFormat`.
    #[clap(long)]
    prompt_format: PromptFormat,
}
183
// Prompt layouts selectable via `--prompt-format`. Derives
// Serialize/Deserialize — presumably persisted alongside examples, so
// renaming variants may break existing files; confirm before renaming.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}
189
// Shared arguments for the `predict`, `score`, and `eval` subcommands.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Which prediction backend to use; see `PredictionProvider`.
    #[clap(long)]
    provider: PredictionProvider,
    // Number of prediction runs per example (defaults to a single run).
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
197
// Prediction backends selectable via `--provider`. Derives
// Serialize/Deserialize — presumably written into example/result files, so
// variant names are part of the persisted format; confirm before renaming.
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    // Presumably the teacher backend without the batch API (see
    // `predict::sync_batches` usage in `main`) — confirm.
    TeacherNonBatching,
}
207
// Arguments for the `synthesize` subcommand (handled synchronously in `main`;
// also requires the global --output flag as the destination directory).
// The `///` comments below are clap help text — user-facing.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URL (git@github.com:owner/repo or https://...)
    #[clap(long)]
    repo: String,

    /// Number of examples to generate
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
226
227impl EpArgs {
228 fn output_path(&self) -> Option<PathBuf> {
229 if self.in_place {
230 if self.inputs.len() == 1 {
231 self.inputs.first().cloned()
232 } else {
233 panic!("--in-place requires exactly one input file")
234 }
235 } else {
236 self.output.clone()
237 }
238 }
239}
240
/// Loads examples from the CLI inputs: plain example files plus any
/// `captured-after:{timestamp}` specifiers fetched from Snowflake, then
/// applies the global `--name` / `--repo` substring filters and `--limit`.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
) -> anyhow::Result<Vec<Example>> {
    // Partition inputs into captured-after timestamps vs. file paths.
    let mut captured_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Initial progress estimate; refined again after filtering below.
    let total_steps = examples.len() + captured_after_timestamps.len();
    Progress::global().set_total_steps(total_steps);

    // How many more examples Snowflake may contribute before --limit is hit
    // (None when no limit was given).
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !captured_after_timestamps.is_empty() {
        captured_after_timestamps.sort();

        // Without an explicit --limit, cap each Snowflake query at 5000 rows.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        let mut captured_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &captured_after_timestamps,
            max_rows_per_timestamp,
        )
        .await?;
        examples.append(&mut captured_examples);
    }

    // Deterministic ordering so repo grouping/batching is stable across runs.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from the global --name / --repo flags.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    if let Some(limit) = args.limit {
        if examples.len() > limit {
            examples.truncate(limit);
        }
    }

    // NOTE(review): the timestamp count is added on top of the filtered
    // example count here, even though fetched examples are already included
    // in `examples` — confirm this matches how progress steps are consumed.
    Progress::global().set_total_steps(examples.len() + captured_after_timestamps.len());

    Ok(examples)
}
302
/// CLI entry point: handles the synchronous commands (`printenv`, help,
/// `clean`, `synthesize`, `split-commit`) up front, then runs the remaining
/// per-example commands inside a headless GPUI application.
fn main() {
    let args = EpArgs::parse();

    // --printenv short-circuits everything else.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    // No subcommand given: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that don't need the headless app or the example pipeline.
    match &command {
        Command::Clean => {
            // Removes all cached repos/worktrees under the data directory.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesize writes generated examples to the --output directory.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_url: synth_args.repo.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) =
                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Everything below runs inside a headless GPUI application so project
    // loading / prediction infrastructure is available.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let mut examples = load_examples(app_state.client.http_client(), &args).await?;

                // For predict, reconcile any outstanding provider batches
                // before processing new examples.
                if let Command::Predict(args) = &command {
                    predict::sync_batches(&args.provider).await?;
                }

                // With a single example, fail loudly even without --failfast.
                let failfast_on_single_example = examples.len() == 1;

                // Examples from the same repo run sequentially within one
                // future; distinct repos run concurrently, at most
                // --max-parallelism repos per batch.
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            let result = async {
                                // Dispatch the per-example stage for the
                                // selected command.
                                match &command {
                                    // Parsing already happened when the
                                    // example files were read.
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args.prompt_format,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            Some(args.provider),
                                            args.repetitions,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    // Eval is scoring plus a final report.
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    // These were handled (and returned from)
                                    // before the app started.
                                    Command::Clean
                                    | Command::Synthesize(_)
                                    | Command::SplitCommit(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            // Per-example failures are recorded and (unless
                            // fail-fast) don't abort the run.
                            if let Err(error) = result {
                                handle_error(
                                    error,
                                    &args,
                                    &command,
                                    &app_state,
                                    failfast_on_single_example,
                                    example,
                                )
                                .await;
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }
                Progress::global().finalize();

                // Eval only writes examples when --output was given; all
                // other commands always persist results.
                if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                    write_examples(&examples, output.as_ref());
                }

                match &command {
                    // Flush/collect provider batches created during this run.
                    Command::Predict(args) => predict::sync_batches(&args.provider).await?,
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app once the pipeline completes.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
470
/// Records a failed example: bumps the global failure counter, writes the
/// example JSON and the error text under `FAILED_EXAMPLES_DIR`, then either
/// panics (fail-fast mode) or logs the failure and lets the run continue.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    let example_name = example.spec.filename();
    // Persist the full example so it can be inspected and re-run in isolation.
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    // Error text goes into a sibling `<name>_err.txt` file.
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    // Absolute path of the file the example's cursor points at, inside the
    // example's worktree.
    let file_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    // `\x1b[36m…\x1b[0m` wraps paths in cyan; `{}` on `command` renders the
    // Display impl above as a re-runnable CLI invocation.
    let msg = format!(
        indoc::indoc! {"
            While processing \"{}\":

            {:?}

            Written to: \x1b[36m{}\x1b[0m

            Cursor File: \x1b[36m{}\x1b[0m

            Explore this example data with:
            fx \x1b[36m{}\x1b[0m

            Re-run this example with:
            cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        err_path.display(),
        file_path.display(),
        failed_example_path.display(),
        command,
        failed_example_path.display(),
    );
    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}