ep: Add a parameter to sample at most N near-duplicates (#49870)

Oleksiy Syvokon created

Duplicates are defined as cursor positions that have an approximate
Jaccard similarity greater than 0.5 (over token 5-grams).

From the resulting clusters of near-duplicates, we select up to N
examples that are maximally different from each other.


Release Notes:

- N/A

Change summary

Cargo.lock                             | 136 +++++++++++++++++++++++
crates/edit_prediction_cli/Cargo.toml  |   1 
crates/edit_prediction_cli/src/main.rs | 155 +++++++++++++++++++++++++++
3 files changed, 283 insertions(+), 9 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -5247,6 +5247,7 @@ dependencies = [
  "flate2",
  "fs",
  "futures 0.3.31",
+ "gaoya",
  "gpui",
  "gpui_platform",
  "gpui_tokio",
@@ -6541,6 +6542,12 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "fuchsia-cprng"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a06f77d526c1a601b7c4cdd98f54b5eaabffc14d5f2f0296febdc7f357c6d3ba"
+
 [[package]]
 name = "funty"
 version = "2.0.0"
@@ -6711,6 +6718,29 @@ dependencies = [
  "thread_local",
 ]
 
+[[package]]
+name = "gaoya"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0c75195ebd4c5589a505e1f0bf81052c52f55dfa40c1afefac1f95b67846adb1"
+dependencies = [
+ "ahash 0.8.12",
+ "crossbeam-utils",
+ "fnv",
+ "itertools 0.10.5",
+ "num-traits",
+ "rand 0.8.5",
+ "rand_pcg",
+ "random_choice",
+ "rayon",
+ "seahash",
+ "sha-1",
+ "shingles",
+ "siphasher 0.3.11",
+ "smallvec",
+ "triomphe",
+]
+
 [[package]]
 name = "gemm"
 version = "0.17.1"
@@ -12287,7 +12317,7 @@ version = "0.11.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5"
 dependencies = [
- "siphasher",
+ "siphasher 1.0.1",
 ]
 
 [[package]]
@@ -12296,7 +12326,7 @@ version = "0.12.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "06005508882fb681fd97892ecff4b7fd0fee13ef1aa569f8695dae7ab9099981"
 dependencies = [
- "siphasher",
+ "siphasher 1.0.1",
 ]
 
 [[package]]
@@ -13343,6 +13373,29 @@ version = "0.7.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
 
+[[package]]
+name = "rand"
+version = "0.3.23"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "64ac302d8f83c0c1974bf758f6b041c6c8ada916fbb44a609158ca8b064cc76c"
+dependencies = [
+ "libc",
+ "rand 0.4.6",
+]
+
+[[package]]
+name = "rand"
+version = "0.4.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "552840b97013b1a26992c11eac34bdd778e464601a4c2054b5f0bff7c6761293"
+dependencies = [
+ "fuchsia-cprng",
+ "libc",
+ "rand_core 0.3.1",
+ "rdrand",
+ "winapi",
+]
+
 [[package]]
 name = "rand"
 version = "0.8.5"
@@ -13384,6 +13437,21 @@ dependencies = [
  "rand_core 0.9.3",
 ]
 
+[[package]]
+name = "rand_core"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7a6fdeb83b075e8266dcc8762c22776f6877a63111121f5f8c7411e5be7eed4b"
+dependencies = [
+ "rand_core 0.4.2",
+]
+
+[[package]]
+name = "rand_core"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9c33a3c44ca05fa6f1807d8e6743f3824e8509beca625669633be0acbdf509dc"
+
 [[package]]
 name = "rand_core"
 version = "0.6.4"
@@ -13412,6 +13480,24 @@ dependencies = [
  "rand 0.9.2",
 ]
 
+[[package]]
+name = "rand_pcg"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "59cad018caf63deb318e5a4586d99a24424a364f40f1e5778c29aca23f4fc73e"
+dependencies = [
+ "rand_core 0.6.4",
+]
+
+[[package]]
+name = "random_choice"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "09c8d23fe09a1d82566c84c9dfa810b0479c6dbbe190728274f68ee3a0c27dbf"
+dependencies = [
+ "rand 0.3.23",
+]
+
 [[package]]
 name = "range-alloc"
 version = "0.1.4"
@@ -13527,6 +13613,15 @@ dependencies = [
  "crossbeam-utils",
 ]
 
+[[package]]
+name = "rdrand"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "678054eb77286b51581ba43620cc911abf02758c91f93f479767aed0f90458b2"
+dependencies = [
+ "rand_core 0.3.1",
+]
+
 [[package]]
 name = "read-fonts"
 version = "0.35.0"
@@ -15308,6 +15403,17 @@ dependencies = [
  "zlog",
 ]
 
+[[package]]
+name = "sha-1"
+version = "0.10.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f5058ada175748e33390e40e872bd0fe59a19f265d0158daa551c5a88a76009c"
+dependencies = [
+ "cfg-if",
+ "cpufeatures",
+ "digest",
+]
+
 [[package]]
 name = "sha1"
 version = "0.10.6"
@@ -15386,6 +15492,12 @@ dependencies = [
  "dirs 6.0.0",
 ]
 
+[[package]]
+name = "shingles"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "72bb884be1ddfbded5873be4672cf5aee71210ce0f8ae99787d158b9b72b5ca0"
+
 [[package]]
 name = "shlex"
 version = "1.3.0"
@@ -15515,6 +15627,12 @@ dependencies = [
  "time",
 ]
 
+[[package]]
+name = "siphasher"
+version = "0.3.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d"
+
 [[package]]
 name = "siphasher"
 version = "1.0.1"
@@ -16330,7 +16448,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "68c7541fff44b35860c1a7a47a7cadf3e4a304c457b58f9870d9706ece028afc"
 dependencies = [
  "kurbo",
- "siphasher",
+ "siphasher 1.0.1",
 ]
 
 [[package]]
@@ -18014,6 +18132,16 @@ dependencies = [
  "tree-sitter-language",
 ]
 
+[[package]]
+name = "triomphe"
+version = "0.1.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dd69c5aa8f924c7519d6372789a74eac5b94fb0f8fcf0d4a97eb0bfc3e785f39"
+dependencies = [
+ "serde",
+ "stable_deref_trait",
+]
+
 [[package]]
 name = "try-lock"
 version = "0.2.5"
@@ -18384,7 +18512,7 @@ dependencies = [
  "roxmltree",
  "rustybuzz",
  "simplecss",
- "siphasher",
+ "siphasher 1.0.1",
  "strict-num",
  "svgtypes",
  "tiny-skia-path",

crates/edit_prediction_cli/Cargo.toml 🔗

@@ -65,6 +65,7 @@ similar = "2.7.0"
 flate2 = "1.1.8"
 toml.workspace = true
 rust-embed = { workspace = true, features = ["debug-embed"] }
+gaoya = "0.2.0"
 
 # Wasmtime is included as a dependency in order to enable the same
 # features that are enabled in Zed.

crates/edit_prediction_cli/src/main.rs 🔗

@@ -27,10 +27,13 @@ mod synthesize;
 mod truncate_expected_patch;
 mod word_diff;
 use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
-use collections::HashSet;
+use collections::{HashMap, HashSet};
 use edit_prediction::EditPredictionStore;
 use futures::channel::mpsc;
 use futures::{SinkExt as _, StreamExt as _};
+use gaoya::minhash::{
+    MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
+};
 use gpui::{AppContext as _, BackgroundExecutor, Task};
 use zeta_prompt::ZetaFormat;
 
@@ -77,9 +80,13 @@ struct EpArgs {
     /// Filter examples by repository
     #[clap(long, global = true)]
     repo: Option<String>,
+    /// Deduplicate by cursor position and keep at most this many examples per cluster
+    #[clap(long, global = true)]
+    max_duplicates: Option<usize>,
     #[command(subcommand)]
     command: Option<Command>,
-    #[clap(global = true, help = INPUTS_HELP)]
+    /// Input file paths
+    #[clap(global = true)]
     inputs: Vec<PathBuf>,
     #[arg(long, short, global = true)]
     output: Option<PathBuf>,
@@ -171,7 +178,7 @@ Examples:
 #[derive(Subcommand, Debug, Clone)]
 enum Command {
     /// Read examples from files or fetch from Snowflake, output as .jsonl
-    Read,
+    Read(ReadArgs),
     /// Create git worktrees for each example and load file contents
     LoadProject,
     /// Retrieve context for input examples.
@@ -215,7 +222,7 @@ enum Command {
 impl Display for Command {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
-            Command::Read => write!(f, "read"),
+            Command::Read(_) => write!(f, "read"),
             Command::LoadProject => write!(f, "load-project"),
             Command::Context => write!(f, "context"),
             Command::FormatPrompt(args) => {
@@ -259,6 +266,10 @@ impl Display for Command {
     }
 }
 
+#[derive(Debug, Args, Clone)]
+#[command(after_help = INPUTS_HELP)]
+struct ReadArgs {}
+
 #[derive(Debug, Args, Clone)]
 struct FormatPromptArgs {
     #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
@@ -481,6 +492,136 @@ const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::Min
     patch: 1,
 };
 
+fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
+    let total_before_exact = examples.len();
+    let mut seen_positions = HashSet::default();
+    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
+    log::info!(
+        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
+        examples.len(),
+        total_before_exact - examples.len(),
+    );
+
+    const JACCARD_THRESHOLD: f64 = 0.5;
+    const NUM_HASHES: usize = 128;
+    const TOKEN_NGRAM_SIZE: usize = 5;
+
+    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
+    let num_hashes = num_bands * band_width;
+    let minhasher = MinHasher32::new(num_hashes);
+    let mut index: MinHashIndex<u32, usize> =
+        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);
+
+    let signatures: Vec<Vec<u32>> = examples
+        .iter()
+        .map(|example| {
+            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
+            minhasher.create_signature(shingles.iter())
+        })
+        .collect();
+
+    for (id, signature) in signatures.iter().enumerate() {
+        index.insert(id, signature.clone());
+    }
+
+    // Build clusters via union-find on LSH candidate pairs.
+    let mut parent: Vec<usize> = (0..examples.len()).collect();
+
+    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
+        while parent[x] != x {
+            parent[x] = parent[parent[x]];
+            x = parent[x];
+        }
+        x
+    }
+
+    for (id, signature) in signatures.iter().enumerate() {
+        for candidate in index.query_owned(signature) {
+            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
+            if a != b {
+                parent[a] = b;
+            }
+        }
+    }
+
+    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
+    for id in 0..examples.len() {
+        clusters.entry(find(&mut parent, id)).or_default().push(id);
+    }
+
+    let mut keep: HashSet<usize> = HashSet::default();
+    for members in clusters.values() {
+        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
+        keep.extend(selected);
+    }
+
+    let total = examples.len();
+    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
+    kept_indices.sort();
+
+    let mut retained = Vec::with_capacity(kept_indices.len());
+    for index in kept_indices.into_iter().rev() {
+        retained.push(examples.swap_remove(index));
+    }
+    retained.reverse();
+
+    *examples = retained;
+    log::info!(
+        "near-duplicate filter: {total} examples → {} examples ({} removed)",
+        examples.len(),
+        total - examples.len(),
+    );
+}
+
+fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
+    if members.len() <= k {
+        return members.to_vec();
+    }
+
+    let mut selected = vec![members[0]];
+    let mut min_dist: HashMap<usize, f64> = HashMap::default();
+    for &member in &members[1..] {
+        let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
+        min_dist.insert(member, dist);
+    }
+
+    while selected.len() < k {
+        let &best = min_dist
+            .iter()
+            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
+            .map(|(id, _)| id)
+            .expect("min_dist should not be empty when selected.len() < k");
+        selected.push(best);
+        min_dist.remove(&best);
+
+        let best_sig = &signatures[best];
+        for (member, current_min) in min_dist.iter_mut() {
+            let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
+            if dist < *current_min {
+                *current_min = dist;
+            }
+        }
+    }
+
+    selected
+}
+
+fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
+    let tokens: Vec<&str> = word_diff::tokenize(code)
+        .into_iter()
+        .filter(|t| !t.trim().is_empty())
+        .collect();
+
+    if tokens.len() < ngram_size {
+        return vec![tokens.join("\0")];
+    }
+
+    tokens
+        .windows(ngram_size)
+        .map(|window| window.join("\0"))
+        .collect()
+}
+
 async fn load_examples(
     http_client: Arc<dyn http_client::HttpClient>,
     args: &EpArgs,
@@ -623,6 +764,10 @@ async fn load_examples(
         }
     }
 
+    if let Some(max_duplicates) = args.max_duplicates {
+        deduplicate_examples(&mut examples, max_duplicates);
+    }
+
     if let Some(limit) = args.limit {
         examples.truncate(limit);
     }
@@ -922,7 +1067,7 @@ fn main() {
 
                                 let result = async {
                                     match &command {
-                                        Command::Read => {}
+                                        Command::Read(_) => {}
                                         Command::LoadProject => {
                                             run_load_project(
                                                 example,