Allow the edit prediction (EP) CLI `synthesize` command to take multiple repos (#46853)

Created by Max Brunsfeld and Agus Zubiaga

Release Notes:

- N/A

Co-authored-by: Agus Zubiaga <agus@zed.dev>

Change summary

crates/edit_prediction_cli/Cargo.toml          |   1 
crates/edit_prediction_cli/src/load_project.rs |   6 
crates/edit_prediction_cli/src/main.rs         |  14 +-
crates/edit_prediction_cli/src/synthesize.rs   | 137 ++++++++++++++-----
4 files changed, 114 insertions(+), 44 deletions(-)

Detailed changes

crates/edit_prediction_cli/Cargo.toml 🔗

@@ -37,6 +37,7 @@ languages = { workspace = true, features = ["load-grammars"] }
 libc.workspace = true
 log.workspace = true
 node_runtime.workspace = true
+
 paths.workspace = true
 project.workspace = true
 prompt_store.workspace = true

crates/edit_prediction_cli/src/load_project.rs 🔗

@@ -5,9 +5,11 @@ use crate::{
     progress::{InfoStyle, Progress, Step, StepProgress},
 };
 use anyhow::{Context as _, Result};
-use edit_prediction::udiff::{OpenedBuffers, refresh_worktree_entries, strip_diff_path_prefix};
 use edit_prediction::{
-    EditPredictionStore, cursor_excerpt::editable_and_context_ranges_for_cursor_position, zeta2,
+    EditPredictionStore,
+    cursor_excerpt::editable_and_context_ranges_for_cursor_position,
+    udiff::{OpenedBuffers, refresh_worktree_entries, strip_diff_path_prefix},
+    zeta2,
 };
 use futures::AsyncWriteExt as _;
 use gpui::{AsyncApp, Entity};

crates/edit_prediction_cli/src/main.rs 🔗

@@ -194,7 +194,7 @@ impl Display for Command {
                     .get_name()
             ),
             Command::Synthesize(args) => {
-                write!(f, "synthesize --repo={}", args.repo)
+                write!(f, "synthesize --repos {}", args.repos.join(" "))
             }
             Command::Clean => write!(f, "clean"),
             Command::SplitCommit(_) => write!(f, "split-commit"),
@@ -244,15 +244,15 @@ enum PredictionProvider {
 
 #[derive(Debug, Args, Clone)]
 struct SynthesizeArgs {
-    /// Repository URL (git@github.com:owner/repo or https://...)
-    #[clap(long)]
-    repo: String,
+    /// Repository URLs (git@github.com:owner/repo or https://...)
+    #[clap(long, required = true, num_args = 1..)]
+    repos: Vec<String>,
 
-    /// Number of examples to generate
+    /// Number of examples to generate per repository
     #[clap(long, default_value_t = 5)]
     count: usize,
 
-    /// Maximum commits to scan before giving up
+    /// Maximum commits to scan per repository before giving up
     #[clap(long, default_value_t = 100)]
     max_commits: usize,
 
@@ -425,7 +425,7 @@ fn main() {
                 panic!("output dir is required");
             };
             let config = SynthesizeConfig {
-                repo_url: synth_args.repo.clone(),
+                repo_urls: synth_args.repos.clone(),
                 count: synth_args.count,
                 max_commits: synth_args.max_commits,
                 output_dir,

crates/edit_prediction_cli/src/synthesize.rs 🔗

@@ -12,6 +12,7 @@ use edit_prediction::{
     example_spec::ExampleSpec,
     udiff::{apply_diff_to_string, edits_for_diff},
 };
+use futures::stream::{FuturesUnordered, StreamExt};
 use indoc::indoc;
 use serde::{Deserialize, Serialize};
 use std::{
@@ -21,7 +22,8 @@ use std::{
 
 #[derive(Debug, Clone)]
 pub struct SynthesizeConfig {
-    pub repo_url: String,
+    pub repo_urls: Vec<String>,
+    /// Number of examples to generate per repository
     pub count: usize,
     pub max_commits: usize,
     pub output_dir: PathBuf,
@@ -57,16 +59,23 @@ impl SynthesizeState {
         Ok(())
     }
 
-    fn is_processed(&self, repo_url: &str, commit_sha: &str) -> bool {
-        self.repositories
-            .get(repo_url)
-            .is_some_and(|repo| repo.processed_commits.contains(commit_sha))
+    fn take_repo_state(&mut self, repo_url: &str) -> RepoState {
+        self.repositories.remove(repo_url).unwrap_or_default()
     }
 
-    fn mark_processed(&mut self, repo_url: &str, commit_sha: &str, examples_count: usize) {
-        let repo = self.repositories.entry(repo_url.to_string()).or_default();
-        repo.processed_commits.insert(commit_sha.to_string());
-        repo.examples_generated += examples_count;
+    fn merge_repo_state(&mut self, repo_url: String, repo_state: RepoState) {
+        self.repositories.insert(repo_url, repo_state);
+    }
+}
+
+impl RepoState {
+    fn is_processed(&self, commit_sha: &str) -> bool {
+        self.processed_commits.contains(commit_sha)
+    }
+
+    fn mark_processed(&mut self, commit_sha: &str, examples_count: usize) {
+        self.processed_commits.insert(commit_sha.to_string());
+        self.examples_generated += examples_count;
     }
 }
 
@@ -108,19 +117,71 @@ pub async fn run_synthesize(config: SynthesizeConfig) -> Result<()> {
     std::os::windows::fs::symlink_dir(&*FAILED_EXAMPLES_DIR, &*LATEST_FAILED_EXAMPLES_DIR)?;
 
     let progress = Progress::global();
-    progress.set_total_examples(config.count);
+    let total_examples = config.count * config.repo_urls.len();
+    progress.set_total_examples(total_examples);
+
+    let client = Arc::new(PlainLlmClient::new()?);
+    let config = Arc::new(config);
+
+    let mut futures: FuturesUnordered<_> = config
+        .repo_urls
+        .iter()
+        .map(|repo_url| {
+            let client = client.clone();
+            let repo_state = state.take_repo_state(repo_url);
+            let config = config.clone();
+            let repo_url = repo_url.clone();
+            async move {
+                let result = synthesize_repo(&client, repo_state, &config, &repo_url).await;
+                (repo_url, result)
+            }
+        })
+        .collect();
+
+    let mut errors = Vec::new();
+    while let Some((repo_url, result)) = futures.next().await {
+        match result {
+            Ok(repo_state) => {
+                state.merge_repo_state(repo_url, repo_state);
+            }
+            Err(e) => {
+                errors.push(e);
+            }
+        }
+    }
+
+    state.save()?;
+
+    progress.finalize();
 
-    let clone_progress = progress.start(Step::Synthesize, "clone");
-    let repo_path = ensure_repo_cloned(&config.repo_url).await?;
+    if let Some(first_error) = errors.into_iter().next() {
+        return Err(first_error);
+    }
+
+    Ok(())
+}
+
+async fn synthesize_repo(
+    client: &PlainLlmClient,
+    mut repo_state: RepoState,
+    config: &SynthesizeConfig,
+    repo_url: &str,
+) -> Result<RepoState> {
+    let progress = Progress::global();
+    let batch_size = config.max_commits;
+
+    let clone_progress = progress.start(Step::Synthesize, &format!("clone {}", repo_url));
+    let repo_path = ensure_repo_cloned(repo_url).await?;
     drop(clone_progress);
 
-    let client = PlainLlmClient::new()?;
     let mut examples_generated = 0;
     let mut commits_skipped = 0;
-    let batch_size = config.max_commits;
 
     'outer: loop {
-        let list_progress = progress.start(Step::Synthesize, "list-commits");
+        let list_progress = progress.start(
+            Step::Synthesize,
+            &format!("{}: list-commits", repo_name_from_url(repo_url)),
+        );
         let commits = list_commits(&repo_path, batch_size, commits_skipped).await?;
         drop(list_progress);
 
@@ -135,7 +196,7 @@ pub async fn run_synthesize(config: SynthesizeConfig) -> Result<()> {
                 break 'outer;
             }
 
-            if !config.fresh && state.is_processed(&config.repo_url, &commit.sha) {
+            if !config.fresh && repo_state.is_processed(&commit.sha) {
                 continue;
             }
 
@@ -143,8 +204,10 @@ pub async fn run_synthesize(config: SynthesizeConfig) -> Result<()> {
                 continue;
             }
 
+            let repo_name = repo_name_from_url(repo_url);
             let commit_label = format!(
-                "{} {}",
+                "{}: {} {}",
+                repo_name,
                 &commit.sha[..8],
                 truncate_message(&commit.message, 40)
             );
@@ -153,28 +216,26 @@ pub async fn run_synthesize(config: SynthesizeConfig) -> Result<()> {
             // Single Claude call to identify and copy hunks
             step_progress.set_substatus("analyzing...");
             let claude_response =
-                match analyze_commit(&client, &config, &commit, step_progress.clone()).await {
+                match analyze_commit(client, repo_url, &commit, step_progress.clone()).await {
                     Ok(Some(response)) => response,
                     Ok(None) => {
                         step_progress.set_info("no pattern", InfoStyle::Normal);
-                        state.mark_processed(&config.repo_url, &commit.sha, 0);
-                        state.save()?;
+                        repo_state.mark_processed(&commit.sha, 0);
                         continue;
                     }
                     Err(e) => {
                         step_progress.set_info(format!("error: {:?}", e), InfoStyle::Warning);
-                        state.mark_processed(&config.repo_url, &commit.sha, 0);
-                        state.save()?;
+                        repo_state.mark_processed(&commit.sha, 0);
                         continue;
                     }
                 };
 
             // Validate and build the example
             step_progress.set_substatus("validating...");
-            match build_example(&config, &commit, &repo_path, &claude_response).await {
+            match build_example(repo_url, &commit, &repo_path, &claude_response).await {
                 Ok(spec) => {
                     let timestamp = Local::now().format("%Y-%m-%d--%H-%M-%S");
-                    let filename = format!("{}.md", timestamp);
+                    let filename = format!("{}--{}.md", repo_name, timestamp);
                     let path = config.output_dir.join(&filename);
                     std::fs::write(&path, spec.to_markdown())?;
                     examples_generated += 1;
@@ -183,7 +244,7 @@ pub async fn run_synthesize(config: SynthesizeConfig) -> Result<()> {
                 Err(rejection_reason) => {
                     log::debug!("Example rejected: {}", rejection_reason);
                     let timestamp = Local::now().format("%Y-%m-%d--%H-%M-%S%.3f");
-                    let filename = format!("{}.md", timestamp);
+                    let filename = format!("{}--{}.md", repo_name, timestamp);
                     let path = FAILED_EXAMPLES_DIR.join(&filename);
                     let content = format_rejected_example(&claude_response, &rejection_reason);
                     if let Err(e) = std::fs::write(&path, content) {
@@ -193,13 +254,19 @@ pub async fn run_synthesize(config: SynthesizeConfig) -> Result<()> {
                 }
             }
 
-            state.mark_processed(&config.repo_url, &commit.sha, 1);
-            state.save()?;
+            repo_state.mark_processed(&commit.sha, 1);
         }
     }
 
-    progress.finalize();
-    Ok(())
+    Ok(repo_state)
+}
+
+fn repo_name_from_url(url: &str) -> String {
+    url.rsplit('/')
+        .next()
+        .unwrap_or(url)
+        .trim_end_matches(".git")
+        .to_string()
 }
 
 fn truncate_message(msg: &str, max_len: usize) -> String {
@@ -305,7 +372,7 @@ async fn list_commits(
     Ok(commits)
 }
 
-fn build_prompt(config: &SynthesizeConfig, commit: &CommitInfo) -> String {
+fn build_prompt(repo_url: &str, commit: &CommitInfo) -> String {
     format!(
         indoc! {r#"
             You are analyzing a git commit to construct a realistic edit prediction example.
@@ -439,7 +506,7 @@ fn build_prompt(config: &SynthesizeConfig, commit: &CommitInfo) -> String {
             - Must be SMALL: 1-15 changed lines (not counting context)
             - Must be clearly predictable from the edit history narrative
         "#},
-        repo_url = config.repo_url,
+        repo_url = repo_url,
         sha = commit.sha,
         message = commit.message,
         expanded_diff = commit.expanded_diff,
@@ -448,13 +515,13 @@ fn build_prompt(config: &SynthesizeConfig, commit: &CommitInfo) -> String {
 
 async fn analyze_commit(
     client: &PlainLlmClient,
-    config: &SynthesizeConfig,
+    repo_url: &str,
     commit: &CommitInfo,
     step_progress: Arc<StepProgress>,
 ) -> Result<Option<ClaudeResponse>> {
     use anthropic::{Message, RequestContent, Role};
 
-    let prompt = build_prompt(config, commit);
+    let prompt = build_prompt(repo_url, commit);
     let messages = vec![Message {
         role: Role::User,
         content: vec![RequestContent::Text {
@@ -652,7 +719,7 @@ fn split_into_hunks(diff: &str) -> Vec<String> {
 
 /// Validate Claude's output by applying diffs and build the ExampleSpec
 async fn build_example(
-    config: &SynthesizeConfig,
+    repo_url: &str,
     commit: &CommitInfo,
     repo_path: &Path,
     response: &ClaudeResponse,
@@ -715,7 +782,7 @@ async fn build_example(
     );
     let mut spec = ExampleSpec {
         name: response.name.clone(),
-        repository_url: config.repo_url.clone(),
+        repository_url: repo_url.to_string(),
         revision: commit.parent_sha.clone(),
         tags: Vec::new(),
         reasoning: Some(reasoning_with_source),