1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod synthesize;
18use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
19use edit_prediction::EditPredictionStore;
20use gpui::Application;
21use reqwest_client::ReqwestClient;
22use serde::{Deserialize, Serialize};
23use std::fmt::Display;
24use std::{path::PathBuf, sync::Arc};
25
26use crate::distill::run_distill;
27use crate::example::{Example, group_examples_by_repo, read_example_files, write_examples};
28use crate::format_prompt::run_format_prompt;
29use crate::load_project::run_load_project;
30use crate::paths::FAILED_EXAMPLES_DIR;
31use crate::predict::run_prediction;
32use crate::progress::Progress;
33use crate::retrieve_context::run_context_retrieval;
34use crate::score::run_scoring;
35use crate::split_commit::SplitCommitArgs;
36use crate::synthesize::{SynthesizeConfig, run_synthesize};
37
// Top-level CLI arguments for the `ep` binary. Field-level comments use `//`
// (not `///`) on purpose: doc comments on clap-derive fields become --help
// text, and we don't want to change the generated help output here.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the inherited shell environment and exit immediately (debug aid).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Max number of repo groups processed concurrently per batch in main().
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Cap on total examples loaded; also budgets the Snowflake fetch
    // (see load_examples).
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Optional: when absent, main() prints help and exits.
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or `captured-after:{timestamp}` specifiers
    // (see INPUTS_HELP for the full grammar).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output path; mutually interacts with --in-place via output_path().
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Write results back to the (single) input file; panics in output_path()
    // if more than one input is given.
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort on the first failed example instead of logging and continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
}
58
// Long-form help text attached to the positional `inputs` argument of EpArgs.
// Kept as a raw string so the embedded newlines and examples reach the
// terminal verbatim.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.

    You can specify this multiple times and mix it with file inputs.

    Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

    Optional:
    EP_SNOWFLAKE_ROLE

Examples:

    # Predict from a file
    ep predict examples.jsonl

    # Predict from captured examples after a timestamp
    ep predict captured-after:2025-01-01T00:00:00Z

    # Mix file inputs and captured-after in the same invocation
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
88
// Subcommands of `ep`. The existing `///` comments double as clap help text,
// so they are left untouched; additional notes use `//`.
//
// Clean, Synthesize, and SplitCommit are handled synchronously in main()
// before the headless app starts; every other variant runs per-example inside
// the app's async loop.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
}
115
116impl Display for Command {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 match self {
119 Command::ParseExample => write!(f, "parse-example"),
120 Command::LoadProject => write!(f, "load-project"),
121 Command::Context => write!(f, "context"),
122 Command::FormatPrompt(format_prompt_args) => write!(
123 f,
124 "format-prompt --prompt-format={}",
125 format_prompt_args
126 .prompt_format
127 .to_possible_value()
128 .unwrap()
129 .get_name()
130 ),
131 Command::Predict(predict_args) => {
132 write!(
133 f,
134 "predict --provider={:?}",
135 predict_args
136 .provider
137 .to_possible_value()
138 .unwrap()
139 .get_name()
140 )
141 }
142 Command::Score(predict_args) => {
143 write!(
144 f,
145 "score --provider={:?}",
146 predict_args
147 .provider
148 .to_possible_value()
149 .unwrap()
150 .get_name()
151 )
152 }
153 Command::Distill => write!(f, "distill"),
154 Command::Eval(predict_args) => write!(
155 f,
156 "eval --provider={:?}",
157 predict_args
158 .provider
159 .to_possible_value()
160 .unwrap()
161 .get_name()
162 ),
163 Command::Synthesize(args) => {
164 write!(f, "synthesize --repo={}", args.repo)
165 }
166 Command::Clean => write!(f, "clean"),
167 Command::SplitCommit(_) => write!(f, "split-commit"),
168 }
169 }
170}
171
// Arguments for the `format-prompt` subcommand. `//` comments only, so the
// clap-generated help text is unchanged.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Which prompt template to render (see PromptFormat).
    #[clap(long)]
    prompt_format: PromptFormat,
}
177
// Prompt template selector for `format-prompt`; values are exposed on the CLI
// in kebab-case via ValueEnum and round-trip through serde.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}
183
// Shared arguments for `predict`, `score`, and `eval`. `//` comments only,
// so the clap-generated help text is unchanged.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Which prediction backend to use (see PredictionProvider).
    #[clap(long)]
    provider: PredictionProvider,
    // How many times to run prediction per example (defaults to a single run).
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
191
// Prediction backends selectable via `--provider`; CLI names are the
// kebab-case ValueEnum forms (e.g. `teacher-non-batching`). Semantics of each
// backend live in the `predict` module — presumably TeacherNonBatching is the
// Teacher provider without request batching; confirm there if it matters.
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}
201
// Arguments for the `synthesize` subcommand. The `///` comments are clap help
// text and are left untouched; main() additionally requires --output as the
// destination directory for generated examples.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URL (git@github.com:owner/repo or https://...)
    #[clap(long)]
    repo: String,

    /// Number of examples to generate
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
220
221impl EpArgs {
222 fn output_path(&self) -> Option<PathBuf> {
223 if self.in_place {
224 if self.inputs.len() == 1 {
225 self.inputs.first().cloned()
226 } else {
227 panic!("--in-place requires exactly one input file")
228 }
229 } else {
230 self.output.clone()
231 }
232 }
233}
234
235async fn load_examples(
236 http_client: Arc<dyn http_client::HttpClient>,
237 args: &EpArgs,
238) -> anyhow::Result<Vec<Example>> {
239 let mut captured_after_timestamps = Vec::new();
240 let mut file_inputs = Vec::new();
241
242 for input in &args.inputs {
243 let input_string = input.to_string_lossy();
244 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
245 captured_after_timestamps.push(timestamp.to_string());
246 } else {
247 file_inputs.push(input.clone());
248 }
249 }
250
251 let mut examples = read_example_files(&file_inputs);
252 let total_steps = examples.len() + captured_after_timestamps.len();
253 Progress::global().set_total_steps(total_steps);
254
255 let remaining_limit_for_snowflake =
256 args.limit.map(|limit| limit.saturating_sub(examples.len()));
257
258 if let Some(0) = remaining_limit_for_snowflake {
259 log::info!(
260 "skipping captured-after inputs because --limit is already satisfied by example files"
261 );
262 } else if !captured_after_timestamps.is_empty() {
263 captured_after_timestamps.sort();
264
265 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
266
267 let mut captured_examples = pull_examples::fetch_captured_examples_after(
268 http_client,
269 &captured_after_timestamps,
270 max_rows_per_timestamp,
271 )
272 .await?;
273 examples.append(&mut captured_examples);
274 }
275
276 crate::example::sort_examples_by_repo_and_rev(&mut examples);
277
278 if let Some(limit) = args.limit {
279 if examples.len() > limit {
280 examples.truncate(limit);
281 }
282 }
283
284 Progress::global().set_total_steps(examples.len() + captured_after_timestamps.len());
285
286 Ok(examples)
287}
288
/// Entry point: parses CLI args, handles the synchronous subcommands
/// (`clean`, `synthesize`, `split-commit`) directly, and runs the remaining
/// per-example commands inside a headless GPUI application.
fn main() {
    let args = EpArgs::parse();

    // --printenv short-circuits everything else (environment debugging aid).
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    // Resolve the effective output path (--output, or the single input for
    // --in-place; panics on invalid --in-place usage).
    let output = args.output_path();
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            // No subcommand given: print help and exit.
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // These subcommands don't operate on loaded examples, so handle them
    // before spinning up the headless app.
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesize needs a destination directory, taken from the raw
            // --output flag (not the resolved `output` path).
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_url: synth_args.repo.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) =
                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Everything below runs inside a headless GPUI application so commands can
    // use the app's client, fs, and executor.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let mut examples = load_examples(app_state.client.http_client(), &args).await?;

                // For predict, reconcile any previously submitted batches
                // before processing new examples.
                if let Command::Predict(args) = &command {
                    predict::sync_batches(&args.provider).await?;
                }

                // With exactly one example there's nothing to continue past,
                // so treat any failure as fatal regardless of --failfast.
                let failfast_on_single_example = examples.len() == 1;

                // Group examples by repository (examples in one repo run
                // serially within a group), then process up to
                // --max-parallelism groups concurrently per batch.
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            // Run the selected command on this example; the
                            // error (if any) is captured per example rather
                            // than aborting the whole batch.
                            let result = async {
                                match &command {
                                    // Parsing happened in load_examples; the
                                    // combined output is written after the loop.
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args.prompt_format,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            Some(args.provider),
                                            args.repetitions,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    // Eval shares the scoring pass; it only
                                    // differs in the report printed afterwards.
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    // Handled synchronously before app.run().
                                    Command::Clean
                                    | Command::Synthesize(_)
                                    | Command::SplitCommit(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            if let Err(e) = result {
                                Progress::global().increment_failed();
                                // Persist the failing example and its error
                                // side by side for later inspection/re-runs.
                                let failed_example_path =
                                    FAILED_EXAMPLES_DIR.join(format!("{}.json", example.spec.name));
                                app_state
                                    .fs
                                    .write(
                                        &failed_example_path,
                                        &serde_json::to_vec_pretty(&example).unwrap(),
                                    )
                                    .await
                                    .unwrap();
                                let err_path = FAILED_EXAMPLES_DIR
                                    .join(format!("{}_err.txt", example.spec.name));
                                app_state
                                    .fs
                                    .write(&err_path, format!("{e:?}").as_bytes())
                                    .await
                                    .unwrap();

                                // \x1b[36m…\x1b[0m renders the paths in cyan;
                                // `command`'s Display impl reconstructs the
                                // CLI invocation for the re-run hint.
                                let msg = format!(
                                    indoc::indoc! {"
                                        While processing \"{}\":

                                        {:?}

                                        Written to: \x1b[36m{}\x1b[0m

                                        Explore this example data with:
                                        fx \x1b[36m{}\x1b[0m

                                        Re-run this example with:
                                        cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
                                    "},
                                    example.spec.name,
                                    e,
                                    err_path.display(),
                                    failed_example_path.display(),
                                    command,
                                    failed_example_path.display(),
                                );
                                if args.failfast || failfast_on_single_example {
                                    Progress::global().finalize();
                                    panic!("{}", msg);
                                } else {
                                    log::error!("{}", msg);
                                }
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }
                Progress::global().finalize();

                // NOTE(review): this checks the raw --output flag rather than
                // the resolved `output`, so `--in-place` with `eval` would
                // skip writing — confirm that's intended.
                if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                    write_examples(&examples, output.as_ref());
                }

                match &command {
                    // Submit/reconcile any batches created during this run.
                    Command::Predict(args) => predict::sync_batches(&args.provider).await?,
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}