From 450c66ce6e24ec10111fc8dd75711663b2e01b5e Mon Sep 17 00:00:00 2001 From: Oleksiy Syvokon Date: Mon, 23 Feb 2026 12:56:57 +0200 Subject: [PATCH] ep: Add a parameter to sample at most N near-duplicates (#49870) Duplicates are defined as cursor positions that have an approximate Jaccard similarity greater than 0.5 (over token 5-grams). From the resulting clusters of near-duplicates, we select up to N examples that are maximally different from each other. Release Notes: - N/A --- Cargo.lock | 136 +++++++++++++++++++++- crates/edit_prediction_cli/Cargo.toml | 1 + crates/edit_prediction_cli/src/main.rs | 155 ++++++++++++++++++++++++- 3 files changed, 283 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cab5f495a3f6103f4f01ebcbeb46c45b4b63053d..5c5328080cd94f78957a6be4930076202c24d51d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5247,6 +5247,7 @@ dependencies = [ "flate2", "fs", "futures 0.3.31", + "gaoya", "gpui", "gpui_platform", "gpui_tokio", @@ -6541,6 +6542,12 @@ dependencies = [ "libc", ] +[[package]] +name = "fuchsia-cprng" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06f77d526c1a601b7c4cdd98f54b5eaabffc14d5f2f0296febdc7f357c6d3ba" + [[package]] name = "funty" version = "2.0.0" @@ -6711,6 +6718,29 @@ dependencies = [ "thread_local", ] +[[package]] +name = "gaoya" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c75195ebd4c5589a505e1f0bf81052c52f55dfa40c1afefac1f95b67846adb1" +dependencies = [ + "ahash 0.8.12", + "crossbeam-utils", + "fnv", + "itertools 0.10.5", + "num-traits", + "rand 0.8.5", + "rand_pcg", + "random_choice", + "rayon", + "seahash", + "sha-1", + "shingles", + "siphasher 0.3.11", + "smallvec", + "triomphe", +] + [[package]] name = "gemm" version = "0.17.1" @@ -12287,7 +12317,7 @@ version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" dependencies = [ - "siphasher", + "siphasher 1.0.1", ] [[package]] @@ -12296,7 +12326,7 @@ version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06005508882fb681fd97892ecff4b7fd0fee13ef1aa569f8695dae7ab9099981" dependencies = [ - "siphasher", + "siphasher 1.0.1", ] [[package]] @@ -13343,6 +13373,29 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" +[[package]] +name = "rand" +version = "0.3.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64ac302d8f83c0c1974bf758f6b041c6c8ada916fbb44a609158ca8b064cc76c" +dependencies = [ + "libc", + "rand 0.4.6", +] + +[[package]] +name = "rand" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "552840b97013b1a26992c11eac34bdd778e464601a4c2054b5f0bff7c6761293" +dependencies = [ + "fuchsia-cprng", + "libc", + "rand_core 0.3.1", + "rdrand", + "winapi", +] + [[package]] name = "rand" version = "0.8.5" @@ -13384,6 +13437,21 @@ dependencies = [ "rand_core 0.9.3", ] +[[package]] +name = "rand_core" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a6fdeb83b075e8266dcc8762c22776f6877a63111121f5f8c7411e5be7eed4b" +dependencies = [ + "rand_core 0.4.2", +] + +[[package]] +name = "rand_core" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c33a3c44ca05fa6f1807d8e6743f3824e8509beca625669633be0acbdf509dc" + [[package]] name = "rand_core" version = "0.6.4" @@ -13412,6 +13480,24 @@ dependencies = [ "rand 0.9.2", ] +[[package]] +name = "rand_pcg" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59cad018caf63deb318e5a4586d99a24424a364f40f1e5778c29aca23f4fc73e" +dependencies = [ + "rand_core 0.6.4", +] + +[[package]] +name = "random_choice" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09c8d23fe09a1d82566c84c9dfa810b0479c6dbbe190728274f68ee3a0c27dbf" +dependencies = [ + "rand 0.3.23", +] + [[package]] name = "range-alloc" version = "0.1.4" @@ -13527,6 +13613,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "rdrand" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "678054eb77286b51581ba43620cc911abf02758c91f93f479767aed0f90458b2" +dependencies = [ + "rand_core 0.3.1", +] + [[package]] name = "read-fonts" version = "0.35.0" @@ -15308,6 +15403,17 @@ dependencies = [ "zlog", ] +[[package]] +name = "sha-1" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5058ada175748e33390e40e872bd0fe59a19f265d0158daa551c5a88a76009c" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha1" version = "0.10.6" @@ -15386,6 +15492,12 @@ dependencies = [ "dirs 6.0.0", ] +[[package]] +name = "shingles" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72bb884be1ddfbded5873be4672cf5aee71210ce0f8ae99787d158b9b72b5ca0" + [[package]] name = "shlex" version = "1.3.0" @@ -15515,6 +15627,12 @@ dependencies = [ "time", ] +[[package]] +name = "siphasher" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" + [[package]] name = "siphasher" version = "1.0.1" @@ -16330,7 +16448,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68c7541fff44b35860c1a7a47a7cadf3e4a304c457b58f9870d9706ece028afc" dependencies = [ "kurbo", - "siphasher", + "siphasher 1.0.1", ] [[package]] @@ -18014,6 +18132,16 @@ dependencies = [ "tree-sitter-language", ] +[[package]] +name = "triomphe" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd69c5aa8f924c7519d6372789a74eac5b94fb0f8fcf0d4a97eb0bfc3e785f39" +dependencies = [ + "serde", + "stable_deref_trait", +] + [[package]] name = "try-lock" version = "0.2.5" @@ -18384,7 +18512,7 @@ dependencies = [ "roxmltree", "rustybuzz", "simplecss", - "siphasher", + "siphasher 1.0.1", "strict-num", "svgtypes", "tiny-skia-path", diff --git a/crates/edit_prediction_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml index 82ddfe504aa7519597696c2b13531b0e14cfcda3..3cb79b44528b66c48c439f7e2433addb34901000 100644 --- a/crates/edit_prediction_cli/Cargo.toml +++ b/crates/edit_prediction_cli/Cargo.toml @@ -65,6 +65,7 @@ similar = "2.7.0" flate2 = "1.1.8" toml.workspace = true rust-embed = { workspace = true, features = ["debug-embed"] } +gaoya = "0.2.0" # Wasmtime is included as a dependency in order to enable the same # features that are enabled in Zed. diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 01845b13c5621973c148988a168c89efb7a46210..836cd433657199125100051dc428bd7636360d30 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -27,10 +27,13 @@ mod synthesize; mod truncate_expected_patch; mod word_diff; use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum}; -use collections::HashSet; +use collections::{HashMap, HashSet}; use edit_prediction::EditPredictionStore; use futures::channel::mpsc; use futures::{SinkExt as _, StreamExt as _}; +use gaoya::minhash::{ + MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity, +}; use gpui::{AppContext as _, BackgroundExecutor, Task}; use zeta_prompt::ZetaFormat; @@ -77,9 +80,13 @@ struct EpArgs { /// Filter examples by repository #[clap(long, global = true)] repo: Option, + /// Deduplicate by cursor position and keep at most this many examples per cluster + #[clap(long, global = true)] + max_duplicates: Option, #[command(subcommand)] command: Option, - #[clap(global = true, help = INPUTS_HELP)] + /// Input file paths + #[clap(global = true)] inputs: Vec, #[arg(long, short, global = true)] output: Option, @@ -171,7 +178,7 @@ Examples: #[derive(Subcommand, Debug, Clone)] enum Command { /// Read examples from files or fetch from Snowflake, output as .jsonl - Read, + Read(ReadArgs), /// Create git worktrees for each example and load file contents LoadProject, /// Retrieve context for input examples. @@ -215,7 +222,7 @@ enum Command { impl Display for Command { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Command::Read => write!(f, "read"), + Command::Read(_) => write!(f, "read"), Command::LoadProject => write!(f, "load-project"), Command::Context => write!(f, "context"), Command::FormatPrompt(args) => { @@ -259,6 +266,10 @@ impl Display for Command { } } +#[derive(Debug, Args, Clone)] +#[command(after_help = INPUTS_HELP)] +struct ReadArgs {} + #[derive(Debug, Args, Clone)] struct FormatPromptArgs { #[clap(long, short('p'), default_value_t = PredictionProvider::default())] @@ -481,6 +492,136 @@ const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::Min patch: 1, }; +fn deduplicate_examples(examples: &mut Vec, max_per_cluster: usize) { + let total_before_exact = examples.len(); + let mut seen_positions = HashSet::default(); + examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone())); + log::info!( + "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)", + examples.len(), + total_before_exact - examples.len(), + ); + + const JACCARD_THRESHOLD: f64 = 0.5; + const NUM_HASHES: usize = 128; + const TOKEN_NGRAM_SIZE: usize = 5; + + let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES); + let num_hashes = num_bands * band_width; + let minhasher = MinHasher32::new(num_hashes); + let mut index: MinHashIndex = + MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD); + + let signatures: Vec> = examples + .iter() + .map(|example| { + let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE); + minhasher.create_signature(shingles.iter()) + }) + .collect(); + + for (id, signature) in signatures.iter().enumerate() { + index.insert(id, signature.clone()); + } + + // Build clusters via union-find on LSH candidate pairs. + let mut parent: Vec = (0..examples.len()).collect(); + + fn find(parent: &mut Vec, mut x: usize) -> usize { + while parent[x] != x { + parent[x] = parent[parent[x]]; + x = parent[x]; + } + x + } + + for (id, signature) in signatures.iter().enumerate() { + for candidate in index.query_owned(signature) { + let (a, b) = (find(&mut parent, id), find(&mut parent, candidate)); + if a != b { + parent[a] = b; + } + } + } + + let mut clusters: HashMap> = HashMap::default(); + for id in 0..examples.len() { + clusters.entry(find(&mut parent, id)).or_default().push(id); + } + + let mut keep: HashSet = HashSet::default(); + for members in clusters.values() { + let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster); + keep.extend(selected); + } + + let total = examples.len(); + let mut kept_indices: Vec = keep.into_iter().collect(); + kept_indices.sort(); + + let mut retained = Vec::with_capacity(kept_indices.len()); + for index in kept_indices.into_iter().rev() { + retained.push(examples.swap_remove(index)); + } + retained.reverse(); + + *examples = retained; + log::info!( + "near-duplicate filter: {total} examples → {} examples ({} removed)", + examples.len(), + total - examples.len(), + ); +} + +fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec], k: usize) -> Vec { + if members.len() <= k { + return members.to_vec(); + } + + let mut selected = vec![members[0]]; + let mut min_dist: HashMap = HashMap::default(); + for &member in &members[1..] { + let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]); + min_dist.insert(member, dist); + } + + while selected.len() < k { + let &best = min_dist + .iter() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(id, _)| id) + .expect("min_dist should not be empty when selected.len() < k"); + selected.push(best); + min_dist.remove(&best); + + let best_sig = &signatures[best]; + for (member, current_min) in min_dist.iter_mut() { + let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]); + if dist < *current_min { + *current_min = dist; + } + } + } + + selected +} + +fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec { + let tokens: Vec<&str> = word_diff::tokenize(code) + .into_iter() + .filter(|t| !t.trim().is_empty()) + .collect(); + + if tokens.len() < ngram_size { + return vec![tokens.join("\0")]; + } + + tokens + .windows(ngram_size) + .map(|window| window.join("\0")) + .collect() +} + async fn load_examples( http_client: Arc, args: &EpArgs, @@ -623,6 +764,10 @@ async fn load_examples( } } + if let Some(max_duplicates) = args.max_duplicates { + deduplicate_examples(&mut examples, max_duplicates); + } + if let Some(limit) = args.limit { examples.truncate(limit); } @@ -922,7 +1067,7 @@ fn main() { let result = async { match &command { - Command::Read => {} + Command::Read(_) => {} Command::LoadProject => { run_load_project( example,