diff --git a/Cargo.lock b/Cargo.lock index d7d354fa9a366159e68ad0eaa0134fffbf3cbac2..ebc40ec0e58e954d76ba3a31358994cf7cab6669 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5295,12 +5295,14 @@ dependencies = [ "pretty_assertions", "project", "prompt_store", + "rand 0.9.2", "release_channel", "reqwest_client", "serde", "serde_json", "settings", "shellexpand 2.1.2", + "similar", "smol", "sqlez", "sqlez_macros", @@ -15077,6 +15079,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" +[[package]] +name = "similar" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" + [[package]] name = "simple_asn1" version = "0.6.3" diff --git a/crates/edit_prediction_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml index b6bace2a2c080626126af96f9ef51e435d6ab8fa..36f264c70ed579865b3af6f25ac1d6690c89603d 100644 --- a/crates/edit_prediction_cli/Cargo.toml +++ b/crates/edit_prediction_cli/Cargo.toml @@ -17,7 +17,7 @@ anyhow.workspace = true anthropic.workspace = true http_client.workspace = true chrono.workspace = true -clap.workspace = true +clap = "4" client.workspace = true cloud_llm_client.workspace= true collections.workspace = true @@ -55,6 +55,8 @@ watch.workspace = true edit_prediction = { workspace = true, features = ["cli-support"] } wasmtime.workspace = true zeta_prompt.workspace = true +rand.workspace = true +similar = "2.7.0" # Wasmtime is included as a dependency in order to enable the same # features that are enabled in Zed. diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 6074bed9b625fc7150442f51440bbb415560aa58..b54ae89409adcc496f56d503e994df3132e76dc7 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -9,8 +9,10 @@ mod metrics; mod paths; mod predict; mod progress; +mod reorder_patch; mod retrieve_context; mod score; +mod split_commit; mod synthesize; use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum}; @@ -30,6 +32,7 @@ use crate::predict::run_prediction; use crate::progress::Progress; use crate::retrieve_context::run_context_retrieval; use crate::score::run_scoring; +use crate::split_commit::SplitCommitArgs; use crate::synthesize::{SynthesizeConfig, run_synthesize}; #[derive(Parser, Debug)] @@ -74,6 +77,8 @@ enum Command { Synthesize(SynthesizeArgs), /// Remove git repositories and worktrees Clean, + /// Generate an evaluation example by splitting a chronologically-ordered commit + SplitCommit(SplitCommitArgs), } impl Display for Command { @@ -127,6 +132,7 @@ impl Display for Command { write!(f, "synthesize --repo={}", args.repo) } Command::Clean => write!(f, "clean"), + Command::SplitCommit(_) => write!(f, "split-commit"), } } } @@ -235,6 +241,13 @@ fn main() { }); return; } + Command::SplitCommit(split_commit_args) => { + if let Err(error) = split_commit::run_split_commit(split_commit_args) { + eprintln!("{error:#}"); + std::process::exit(1); + } + return; + } _ => {} } @@ -302,7 +315,9 @@ fn main() { run_scoring(example, &args, app_state.clone(), cx.clone()) .await?; } - Command::Clean | Command::Synthesize(_) => { + Command::Clean + | Command::Synthesize(_) + | Command::SplitCommit(_) => { unreachable!() } } diff --git a/crates/edit_prediction_cli/src/reorder_patch.rs b/crates/edit_prediction_cli/src/reorder_patch.rs new file mode 100644 index 
0000000000000000000000000000000000000000..e1fd47624d7484346acad70d7ba87635106b3c43
--- /dev/null
+++ b/crates/edit_prediction_cli/src/reorder_patch.rs
@@ -0,0 +1,1462 @@
+#![allow(unused)]
+
+use std::collections::{BTreeMap, BTreeSet, HashMap};
+
+/// Reorder selected groups of edits (additions & deletions) into a new patch.
+///
+/// Intuition:
+/// Think of the original patch as a timeline of atomic edit indices (0..N),
+/// where one edit is one deleted or inserted line.
+/// This function recombines these edits into a new patch, which can be thought
+/// of as a sequence of patches.
+///
+/// You provide `edits_order` describing logical chunks (e.g., "write a feature",
+/// "refactor", "add tests"). For each group the function:
+/// 1. Extracts those edits
+/// 2. Appends them to the output patch
+/// 3. Removes them from an internal remainder so subsequent original indices
+///    still point to the right (yet-to-be-extracted) edits.
+///
+/// The returned `Patch` contains only the edits you listed, emitted group by
+/// group. The leftover remainder is discarded.
+///
+/// Parameters:
+/// * `patch` - Source patch
+/// * `edits_order` - Vector of sets of original (0-based) edit indexes
+///
+/// Returns:
+/// * A new `Patch` containing the grouped edits in the requested order.
+///
+/// Example:
+/// ```rust
+/// use std::collections::BTreeSet;
+/// use reorder_patch::{Patch, reorder_edits};
+///
+/// // Edits (indexes): 0:-old, 1:+new, 2:-old2, 3:+new2, 4:+added
+/// let diff = "\
+/// --- a/a.txt
+/// +++ b/a.txt
+/// @@ -1,3 +1,3 @@
+///  one
+/// -old
+/// +new
+///  end
+/// @@ -5,3 +5,4 @@
+///  tail
+/// -old2
+/// +new2
+/// +added
+///  fin
+/// ";
+/// let patch = Patch::parse_unified_diff(diff);
+///
+/// // First take part of the second hunk's edits (2),
+/// // then the first hunk (0,1), then the rest of the second hunk (3,4)
+/// let order = vec![BTreeSet::from([2]), BTreeSet::from([0, 1]), BTreeSet::from([3, 4])];
+/// let reordered = reorder_edits(&patch, order);
+/// println!("{}", reordered.to_string());
+/// ```
+pub fn reorder_edits(patch: &Patch, edits_order: Vec<BTreeSet<usize>>) -> Patch {
+    let mut result = Patch {
+        header: patch.header.clone(),
+        hunks: Vec::new(),
+    };
+
+    let mut remainder = patch.clone();
+
+    // Indexes in `edits_order` will shift as we apply edits.
+    // This structure maps the original index to the actual index.
+    let stats = patch.stats();
+    let total_edits = stats.added + stats.removed;
+    let mut indexes_map = BTreeMap::from_iter((0..total_edits).map(|i| (i, Some(i))));
+
+    for patch_edits_order in edits_order {
+        // Skip duplicated indexes that were already processed
+        let patch_edits_order = patch_edits_order
+            .into_iter()
+            .filter(|&i| indexes_map[&i].is_some())
+            .collect::<BTreeSet<usize>>();
+
+        if patch_edits_order.is_empty() {
+            continue;
+        }
+
+        let order = patch_edits_order
+            .iter()
+            .map(|&i| {
+                indexes_map[&i].unwrap_or_else(|| {
+                    panic!("Edit index {i} has already been used. Perhaps your spec contains duplicates")
+                })
+            })
+            .collect::<BTreeSet<usize>>();
+
+        let extracted;
+        (extracted, remainder) = extract_edits(&remainder, &order);
+
+        result.hunks.extend(extracted.hunks);
+
+        // Update indexes_map to reflect applied edits. For example:
+        //
+        // Original_index | Removed?
| Mapped_value + // 0 | false | 0 + // 1 | true | None + // 2 | true | None + // 3 | false | 1 + + for index in patch_edits_order { + indexes_map.insert(index, None); + for j in (index + 1)..total_edits { + if let Some(val) = indexes_map[&j] { + indexes_map.insert(j, Some(val - 1)); + } + } + } + } + + result +} + +/// Split a patch into (extracted, remainder) based on a set of edit indexes. +/// The first returned patch contains only the chosen edits; the second contains +/// everything else with those edits applied (converted into context). +pub fn extract_edits(patch: &Patch, edit_indexes: &BTreeSet) -> (Patch, Patch) { + let mut extracted = patch.clone(); + let mut remainder = patch.clone(); + + let stats = patch.stats(); + let num_edits = stats.added + stats.removed; + let this_edits = edit_indexes.iter().cloned().collect::>(); + let other_edits = (0..num_edits) + .filter(|i| !edit_indexes.contains(i)) + .collect(); + + remove_edits(&mut extracted, other_edits); + apply_edits(&mut remainder, this_edits); + + (extracted, remainder) +} + +#[derive(Debug, Default, Clone)] +pub struct Patch { + pub header: String, + pub hunks: Vec, +} + +pub struct DiffStats { + pub added: usize, + pub removed: usize, +} + +impl ToString for Patch { + fn to_string(&self) -> String { + let mut result = self.header.clone(); + for hunk in &self.hunks { + let current_file = hunk.filename.clone(); + result.push_str(&format!("--- a/{}\n", current_file)); + result.push_str(&format!("+++ b/{}\n", current_file)); + result.push_str(&hunk.to_string()); + } + + result + } +} + +impl Patch { + /// Parse a unified diff (git style) string into a `Patch`. + pub fn parse_unified_diff(unified_diff: &str) -> Patch { + let mut current_file = String::new(); + let mut is_filename_inherited = false; + let mut hunk = Hunk::default(); + let mut patch = Patch::default(); + let mut in_header = true; + + for line in unified_diff.lines() { + if line.starts_with("--- ") || line.starts_with("+++ ") || line.starts_with("@@") { + in_header = false; + } + + if in_header { + patch.header.push_str(format!("{}\n", &line).as_ref()); + continue; + } + + if line.starts_with("@@") { + if !hunk.lines.is_empty() { + patch.hunks.push(hunk); + } + hunk = Hunk::from_header(line, ¤t_file, is_filename_inherited); + is_filename_inherited = true; + } else if let Some(path) = line.strip_prefix("--- ") { + is_filename_inherited = false; + current_file = path.trim().strip_prefix("a/").unwrap_or(path).into(); + } else if let Some(path) = line.strip_prefix("+++ ") { + is_filename_inherited = false; + current_file = path.trim().strip_prefix("b/").unwrap_or(path).into(); + } else if let Some(line) = line.strip_prefix("+") { + hunk.lines.push(PatchLine::Addition(line.to_string())); + } else if let Some(line) = line.strip_prefix("-") { + hunk.lines.push(PatchLine::Deletion(line.to_string())); + } else if let Some(line) = line.strip_prefix(" ") { + hunk.lines.push(PatchLine::Context(line.to_string())); + } else { + hunk.lines.push(PatchLine::Garbage(line.to_string())); + } + } + + if !hunk.lines.is_empty() { + patch.hunks.push(hunk); + } + + let header_lines = patch.header.lines().collect::>(); + let len = header_lines.len(); + if len >= 2 { + if header_lines[len - 2].starts_with("diff --git") + && header_lines[len - 1].starts_with("index ") + { + patch.header = header_lines[..len - 2].join("\n") + "\n"; + } + } + if patch.header.trim().is_empty() { + patch.header = String::new(); + } + + patch + } + + /// Drop hunks that contain no additions or deletions. 
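+    ///
+    /// A minimal usage sketch (the inline diff is illustrative):
+    ///
+    /// ```rust,ignore
+    /// let mut patch = Patch::parse_unified_diff("--- a/f.txt\n+++ b/f.txt\n@@ -1,2 +1,2 @@\n one\n two\n");
+    /// patch.remove_empty_hunks(); // the context-only hunk is dropped
+    /// assert!(patch.hunks.is_empty());
+    /// ```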
+ pub fn remove_empty_hunks(&mut self) { + self.hunks.retain(|hunk| { + hunk.lines + .iter() + .any(|line| matches!(line, PatchLine::Addition(_) | PatchLine::Deletion(_))) + }); + } + + /// Make sure there are no more than `context_lines` lines of context around each change. + pub fn normalize_hunks(&mut self, context_lines: usize) { + for hunk in &mut self.hunks { + // Find indices of all changes (additions and deletions) + let change_indices: Vec = hunk + .lines + .iter() + .enumerate() + .filter_map(|(i, line)| match line { + PatchLine::Addition(_) | PatchLine::Deletion(_) => Some(i), + _ => None, + }) + .collect(); + + // If there are no changes, clear the hunk (it's all context) + if change_indices.is_empty() { + hunk.lines.clear(); + hunk.old_count = 0; + hunk.new_count = 0; + continue; + } + + // Determine the range to keep + let first_change = change_indices[0]; + let last_change = change_indices[change_indices.len() - 1]; + + let start = first_change.saturating_sub(context_lines); + let end = (last_change + context_lines + 1).min(hunk.lines.len()); + + // Count lines trimmed from the beginning + let (old_lines_before, new_lines_before) = count_lines(&hunk.lines[0..start]); + + // Keep only the lines in range + garbage + let garbage_before = hunk.lines[..start] + .iter() + .filter(|line| matches!(line, PatchLine::Garbage(_))); + let garbage_after = hunk.lines[end..] + .iter() + .filter(|line| matches!(line, PatchLine::Garbage(_))); + + hunk.lines = garbage_before + .chain(hunk.lines[start..end].iter()) + .chain(garbage_after) + .cloned() + .collect(); + + // Update hunk header + let (old_count, new_count) = count_lines(&hunk.lines); + hunk.old_start += old_lines_before as isize; + hunk.new_start += new_lines_before as isize; + hunk.old_count = old_count as isize; + hunk.new_count = new_count as isize; + } + } + + /// Count total added and removed lines + pub fn stats(&self) -> DiffStats { + let mut added = 0; + let mut removed = 0; + + for hunk in &self.hunks { + for line in &hunk.lines { + match line { + PatchLine::Addition(_) => added += 1, + PatchLine::Deletion(_) => removed += 1, + _ => {} + } + } + } + + DiffStats { added, removed } + } +} + +#[derive(Debug, Default, Clone)] +pub struct Hunk { + pub old_start: isize, + pub old_count: isize, + pub new_start: isize, + pub new_count: isize, + pub comment: String, + pub filename: String, + pub is_filename_inherited: bool, + pub lines: Vec, +} + +impl ToString for Hunk { + fn to_string(&self) -> String { + let header = self.header_string(); + let lines = self + .lines + .iter() + .map(|line| line.to_string() + "\n") + .collect::>() + .join(""); + format!("{header}\n{lines}") + } +} + +impl Hunk { + /// Render the hunk header + pub fn header_string(&self) -> String { + format!( + "@@ -{},{} +{},{} @@ {}", + self.old_start, + self.old_count, + self.new_start, + self.new_count, + self.comment.clone() + ) + .trim_end() + .into() + } + + /// Create a `Hunk` from a raw header line and associated filename. 
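+    ///
+    /// For example, a sketch of the parsed fields (the filename is illustrative):
+    ///
+    /// ```rust,ignore
+    /// let hunk = Hunk::from_header("@@ -3,2 +4,5 @@ fn main()", "src/main.rs", false);
+    /// assert_eq!((hunk.old_start, hunk.old_count), (3, 2));
+    /// assert_eq!((hunk.new_start, hunk.new_count), (4, 5));
+    /// assert_eq!(hunk.comment, "fn main()"); // trailing header text becomes the comment
+    /// ```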
+ pub fn from_header(header: &str, filename: &str, is_filename_inherited: bool) -> Self { + let (old_start, old_count, new_start, new_count, comment) = Self::parse_hunk_header(header); + Self { + old_start, + old_count, + new_start, + new_count, + comment, + filename: filename.to_string(), + is_filename_inherited, + lines: Vec::new(), + } + } + + /// Parse hunk headers like `@@ -3,2 +3,2 @@ some garbage" + fn parse_hunk_header(line: &str) -> (isize, isize, isize, isize, String) { + let header_part = line.trim_start_matches("@@").trim(); + let parts: Vec<&str> = header_part.split_whitespace().collect(); + + if parts.len() < 2 { + return (0, 0, 0, 0, String::new()); + } + + let old_part = parts[0].trim_start_matches('-'); + let new_part = parts[1].trim_start_matches('+'); + + let (old_start, old_count) = Hunk::parse_hunk_header_range(old_part); + let (new_start, new_count) = Hunk::parse_hunk_header_range(new_part); + + let comment = if parts.len() > 2 { + parts[2..] + .join(" ") + .trim_start_matches("@@") + .trim() + .to_string() + } else { + String::new() + }; + + ( + old_start as isize, + old_count as isize, + new_start as isize, + new_count as isize, + comment, + ) + } + + fn parse_hunk_header_range(part: &str) -> (usize, usize) { + let (old_start, old_count) = if part.contains(',') { + let old_parts: Vec<&str> = part.split(',').collect(); + ( + old_parts[0].parse().unwrap_or(0), + old_parts[1].parse().unwrap_or(0), + ) + } else { + (part.parse().unwrap_or(0), 1) + }; + (old_start, old_count) + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum PatchLine { + Context(String), + Addition(String), + Deletion(String), + HunkHeader(usize, usize, usize, usize, String), + FileStartMinus(String), + FileStartPlus(String), + Garbage(String), +} + +impl PatchLine { + pub fn parse(line: &str) -> Self { + if let Some(line) = line.strip_prefix("+") { + Self::Addition(line.to_string()) + } else if let Some(line) = line.strip_prefix("-") { + Self::Deletion(line.to_string()) + } else if let Some(line) = line.strip_prefix(" ") { + Self::Context(line.to_string()) + } else { + Self::Garbage(line.to_string()) + } + } +} + +impl ToString for PatchLine { + fn to_string(&self) -> String { + match self { + PatchLine::Context(line) => format!(" {}", line), + PatchLine::Addition(line) => format!("+{}", line), + PatchLine::Deletion(line) => format!("-{}", line), + PatchLine::HunkHeader(old_start, old_end, new_start, new_end, comment) => format!( + "@@ -{},{} +{},{} @@ {}", + old_start, old_end, new_start, new_end, comment + ) + .trim_end() + .into(), + PatchLine::FileStartMinus(filename) => format!("--- {}", filename), + PatchLine::FileStartPlus(filename) => format!("+++ {}", filename), + PatchLine::Garbage(line) => line.to_string(), + } + } +} + +/// +/// Removes specified edits from a patch by their indexes and adjusts line numbers accordingly. +/// +/// This function removes edits (additions and deletions) from the patch as they never were made. +/// The resulting patch is adjusted to maintain correctness. 
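+/// Deletions that are dropped turn back into context lines, dropped additions
+/// disappear entirely, and later hunk headers in the same file are renumbered
+/// to stay consistent.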
+/// +/// # Arguments +/// +/// * `patch` - A patch to modify +/// * `edit_indexes` - A vector of edit indexes to remove (0-based, counting only additions and deletions) +/// ``` +pub fn remove_edits(patch: &mut Patch, edit_indexes: Vec) { + let mut current_edit_index: isize = -1; + let mut new_start_delta_by_file: HashMap = HashMap::new(); + + for hunk in &mut patch.hunks { + if !hunk.is_filename_inherited { + new_start_delta_by_file.insert(hunk.filename.clone(), 0); + } + let delta = new_start_delta_by_file + .entry(hunk.filename.clone()) + .or_insert(0); + hunk.new_start += *delta; + + hunk.lines = hunk + .lines + .drain(..) + .filter_map(|line| { + let is_edit = matches!(line, PatchLine::Addition(_) | PatchLine::Deletion(_)); + if is_edit { + current_edit_index += 1; + if !edit_indexes.contains(&(current_edit_index as usize)) { + return Some(line); + } + } + match line { + PatchLine::Addition(_) => { + hunk.new_count -= 1; + *delta -= 1; + None + } + PatchLine::Deletion(content) => { + hunk.new_count += 1; + *delta += 1; + Some(PatchLine::Context(content)) + } + _ => Some(line), + } + }) + .collect(); + } + + patch.normalize_hunks(3); + patch.remove_empty_hunks(); +} + +/// +/// Apply specified edits in the patch. +/// +/// This generates another patch that looks like selected edits are already made +/// and became part of the context +/// +/// See also: `remove_edits()` +/// +pub fn apply_edits(patch: &mut Patch, edit_indexes: Vec) { + let mut current_edit_index: isize = -1; + let mut delta_by_file: HashMap = HashMap::new(); + + for hunk in &mut patch.hunks { + if !hunk.is_filename_inherited { + delta_by_file.insert(hunk.filename.clone(), 0); + } + let delta = delta_by_file.entry(hunk.filename.clone()).or_insert(0); + hunk.old_start += *delta; + + hunk.lines = hunk + .lines + .drain(..) + .filter_map(|line| { + let is_edit = matches!(line, PatchLine::Addition(_) | PatchLine::Deletion(_)); + if is_edit { + current_edit_index += 1; + if !edit_indexes.contains(&(current_edit_index as usize)) { + return Some(line); + } + } + match line { + PatchLine::Addition(content) => { + hunk.old_count += 1; + *delta += 1; + Some(PatchLine::Context(content)) + } + PatchLine::Deletion(_) => { + hunk.old_count -= 1; + *delta -= 1; + None + } + _ => Some(line), + } + }) + .collect(); + } + + patch.normalize_hunks(3); + patch.remove_empty_hunks(); +} + +/// Parse an order specification text into groups of edit indexes. +/// Supports numbers, ranges (a-b), commas, comments starting with `//`, and blank lines. 
+///
+/// # Example spec
+///
+/// // Add new dependency
+/// 1, 49
+///
+/// // Add new imports and types
+/// 8-9, 51
+///
+/// // Add new struct and methods
+/// 10-47
+///
+/// // Update tests
+/// 48, 50
+///
+pub fn parse_order_spec(spec: &str) -> Vec<BTreeSet<usize>> {
+    let mut order = Vec::new();
+
+    for line in spec.lines() {
+        let line = line.trim();
+
+        // Skip empty lines and comments
+        if line.is_empty() || line.starts_with("//") {
+            continue;
+        }
+
+        // Parse the line into a BTreeSet
+        let mut set = BTreeSet::new();
+
+        for part in line.split(',') {
+            let part = part.trim();
+
+            if part.contains('-') {
+                // Handle ranges like "8-9" or "10-47"
+                let range_parts: Vec<&str> = part.split('-').collect();
+                if range_parts.len() == 2 {
+                    if let (Ok(start), Ok(end)) = (
+                        range_parts[0].parse::<usize>(),
+                        range_parts[1].parse::<usize>(),
+                    ) {
+                        for i in start..=end {
+                            set.insert(i);
+                        }
+                    } else {
+                        eprintln!("Warning: Invalid range format '{}'", part);
+                    }
+                } else {
+                    eprintln!("Warning: Invalid range format '{}'", part);
+                }
+            } else {
+                // Handle single numbers
+                if let Ok(num) = part.parse::<usize>() {
+                    set.insert(num);
+                } else {
+                    eprintln!("Warning: Invalid number format '{}'", part);
+                }
+            }
+        }
+
+        if !set.is_empty() {
+            order.push(set);
+        }
+    }
+
+    order
+}
+
+#[derive(Debug, Eq, PartialEq)]
+pub struct EditLocation {
+    pub filename: String,
+    pub source_line_number: usize,
+    pub target_line_number: usize,
+    pub patch_line: PatchLine,
+    pub hunk_index: usize,
+    pub line_index_within_hunk: usize,
+}
+
+#[derive(Debug, Eq, PartialEq)]
+pub enum EditType {
+    Deletion,
+    Insertion,
+}
+
+pub fn locate_edited_line(patch: &Patch, mut edit_index: isize) -> Option<EditLocation> {
+    let mut edit_locations = vec![];
+
+    for (hunk_index, hunk) in patch.hunks.iter().enumerate() {
+        let mut old_line_number = hunk.old_start;
+        let mut new_line_number = hunk.new_start;
+        for (line_index, line) in hunk.lines.iter().enumerate() {
+            if matches!(line, PatchLine::Context(_)) {
+                old_line_number += 1;
+                new_line_number += 1;
+                continue;
+            }
+
+            if !matches!(line, PatchLine::Addition(_) | PatchLine::Deletion(_)) {
+                continue;
+            }
+
+            // old new
+            // 1   1   context
+            // 2   2   context
+            // 3   3   -deleted
+            // 4   3   +insert
+            // 4   4   more context
+            //
+            // old new
+            // 1   1   context
+            // 2   2   context
+            // 3   3   +inserted
+            // 3   4   more context
+            //
+            // old new
+            // 1   1   -deleted
+            //
+            // old new
+            // 1   1   context
+            // 2   2   context
+            // 3   3   -deleted
+            // 4   3   more context
+
+            edit_locations.push(EditLocation {
+                filename: hunk.filename.clone(),
+                source_line_number: old_line_number as usize,
+                target_line_number: new_line_number as usize,
+                patch_line: line.clone(),
+                hunk_index,
+                line_index_within_hunk: line_index,
+            });
+
+            match line {
+                PatchLine::Addition(_) => new_line_number += 1,
+                PatchLine::Deletion(_) => old_line_number += 1,
+                _ => (),
+            };
+        }
+    }
+
+    if edit_index < 0 {
+        edit_index += edit_locations.len() as isize; // take from end
+    }
+    (0..edit_locations.len())
+        .contains(&(edit_index as usize))
+        .then(|| edit_locations.swap_remove(edit_index as usize)) // remove to take ownership
+}
+
+// Helper function to count old and new lines
+fn count_lines(lines: &[PatchLine]) -> (usize, usize) {
+    lines.iter().fold((0, 0), |(old, new), line| match line {
+        PatchLine::Context(_) => (old + 1, new + 1),
+        PatchLine::Deletion(_) => (old + 1, new),
+        PatchLine::Addition(_) => (old, new + 1),
+        _ => (old, new),
+    })
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use indoc::indoc;
+
+    #[test]
+    fn
test_parse_unified_diff() { + let patch_str = indoc! {" + Patch header + ============ + + diff --git a/text.txt b/text.txt + index 86c770d..a1fd855 100644 + --- a/text.txt + +++ b/text.txt + @@ -1,7 +1,7 @@ + azuere + beige + black + -blue + +dark blue + brown + cyan + gold + + Some garbage + + diff --git a/second.txt b/second.txt + index 86c770d..a1fd855 100644 + --- a/second.txt + +++ b/second.txt + @@ -9,6 +9,7 @@ gray + green + indigo + magenta + +silver + orange + pink + purple + diff --git a/text.txt b/text.txt + index 86c770d..a1fd855 100644 + --- a/text.txt + +++ b/text.txt + @@ -16,4 +17,3 @@ red + violet + white + yellow + -zinc + "}; + let patch = Patch::parse_unified_diff(patch_str); + + assert_eq!(patch.header, "Patch header\n============\n\n"); + assert_eq!(patch.hunks.len(), 3); + assert_eq!(patch.hunks[0].header_string(), "@@ -1,7 +1,7 @@"); + assert_eq!(patch.hunks[1].header_string(), "@@ -9,6 +9,7 @@ gray"); + assert_eq!(patch.hunks[2].header_string(), "@@ -16,4 +17,3 @@ red"); + assert_eq!(patch.hunks[0].is_filename_inherited, false); + assert_eq!(patch.hunks[1].is_filename_inherited, false); + assert_eq!(patch.hunks[2].is_filename_inherited, false); + } + + #[test] + fn test_locate_edited_line() { + let patch_str = indoc! {" + Patch header + ============ + + diff --git a/text.txt b/text.txt + index 86c770d..a1fd855 100644 + --- a/text.txt + +++ b/text.txt + @@ -1,7 +1,7 @@ + azuere + beige + black + -blue + +dark blue + brown + cyan + gold + diff --git a/second.txt b/second.txt + index 86c770d..a1fd855 100644 + --- a/second.txt + +++ b/second.txt + @@ -9,6 +9,7 @@ gray + green + indigo + magenta + +silver + orange + pink + purple + diff --git a/text.txt b/text.txt + index 86c770d..a1fd855 100644 + --- a/text.txt + +++ b/text.txt + @@ -16,4 +17,3 @@ red + violet + white + yellow + -zinc + "}; + let patch = Patch::parse_unified_diff(patch_str); + + assert_eq!( + locate_edited_line(&patch, 0), // -blue + Some(EditLocation { + filename: "text.txt".to_string(), + source_line_number: 4, + target_line_number: 4, + patch_line: PatchLine::Deletion("blue".to_string()), + hunk_index: 0, + line_index_within_hunk: 3 + }) + ); + assert_eq!( + locate_edited_line(&patch, 1), // +dark blue + Some(EditLocation { + filename: "text.txt".to_string(), + source_line_number: 5, + target_line_number: 4, + patch_line: PatchLine::Addition("dark blue".to_string()), + hunk_index: 0, + line_index_within_hunk: 4 + }) + ); + assert_eq!( + locate_edited_line(&patch, 2), // +silver + Some(EditLocation { + filename: "second.txt".to_string(), + source_line_number: 12, + target_line_number: 12, + patch_line: PatchLine::Addition("silver".to_string()), + hunk_index: 1, + line_index_within_hunk: 3 + }) + ); + } + + mod remove_edits { + use super::*; + use indoc::indoc; + use pretty_assertions::assert_eq; + + static PATCH: &'static str = indoc! 
{" + diff --git a/text.txt b/text.txt + index 86c770d..a1fd855 100644 + --- a/text.txt + +++ b/text.txt + @@ -1,7 +1,7 @@ + azuere + beige + black + -blue + +dark blue + brown + cyan + gold + @@ -9,6 +9,7 @@ gray + green + indigo + magenta + +silver + orange + pink + purple + @@ -16,4 +17,3 @@ red + violet + white + yellow + -zinc + "}; + + #[test] + fn test_removes_hunks_without_edits() { + // Remove the first two edits: + // -blue + // +dark blue + let mut patch = Patch::parse_unified_diff(PATCH); + remove_edits(&mut patch, vec![0, 1]); + + // The whole hunk should be removed since there are no other edits in it + let actual = patch.to_string(); + let expected = indoc! {" + --- a/text.txt + +++ b/text.txt + @@ -9,6 +9,7 @@ gray + green + indigo + magenta + +silver + orange + pink + purple + --- a/text.txt + +++ b/text.txt + @@ -16,4 +17,3 @@ red + violet + white + yellow + -zinc + "}; + assert_eq!(actual, String::from(expected)); + } + + #[test] + fn test_adjust_line_numbers_after_deletion() { + // Remove the first deletion (`-blue`) + let mut patch = Patch::parse_unified_diff(PATCH); + remove_edits(&mut patch, vec![0]); + + // The line numbers should be adjusted in the subsequent hunks + println!("{}", &patch.to_string()); + assert_eq!(patch.hunks[0].header_string(), "@@ -2,6 +2,7 @@"); + assert_eq!(patch.hunks[1].header_string(), "@@ -9,6 +10,7 @@ gray"); + assert_eq!(patch.hunks[2].header_string(), "@@ -16,4 +18,3 @@ red"); + } + #[test] + fn test_adjust_line_numbers_after_insertion() { + // Remove the first insertion (`+dark blue`) + let mut patch = Patch::parse_unified_diff(PATCH); + remove_edits(&mut patch, vec![1]); + + // The line numbers should be adjusted in the subsequent hunks + assert_eq!(patch.hunks[0].header_string(), "@@ -1,7 +1,6 @@"); + assert_eq!(patch.hunks[1].header_string(), "@@ -9,6 +8,7 @@ gray"); + assert_eq!(patch.hunks[2].header_string(), "@@ -16,4 +16,3 @@ red"); + } + #[test] + fn test_adjust_line_numbers_multifile_case() { + // Given a patch that spans multiple files + let patch_str = indoc! {" + --- a/first.txt + +++ b/first.txt + @@ -1,7 +1,7 @@ + azuere + beige + black + -blue + +dark blue + brown + cyan + gold + @@ -16,4 +17,3 @@ red + violet + white + yellow + -zinc + --- a/second.txt + +++ b/second.txt + @@ -9,6 +9,7 @@ gray + green + indigo + magenta + +silver + orange + pink + purple + "}; + + // When removing edit from one of the files (`+dark blue`) + let mut patch = Patch::parse_unified_diff(patch_str); + remove_edits(&mut patch, vec![1]); + + // Then the line numbers should only be adjusted in subsequent hunks from that file + assert_eq!(patch.hunks[0].header_string(), "@@ -1,7 +1,6 @@"); // edited hunk + assert_eq!(patch.hunks[1].header_string(), "@@ -16,4 +16,3 @@ red"); // hunk from edited file again + assert_eq!(patch.hunks[2].header_string(), "@@ -9,6 +9,7 @@ gray"); // hunk from another file + + // When removing hunk from `second.txt` + let mut patch = Patch::parse_unified_diff(patch_str); + remove_edits(&mut patch, vec![3]); + + // Then patch serialization should list `first.txt` only once + // (because hunks from that file become adjacent) + let expected = indoc! 
{" + --- a/first.txt + +++ b/first.txt + @@ -1,7 +1,7 @@ + azuere + beige + black + -blue + +dark blue + brown + cyan + gold + --- a/first.txt + +++ b/first.txt + @@ -16,4 +17,3 @@ red + violet + white + yellow + -zinc + "}; + assert_eq!(patch.to_string(), expected); + } + + #[test] + fn test_dont_adjust_line_numbers_samefile_case() { + // Given a patch that has hunks in the same file, but with a file header + // (which makes `git apply` flush edits so far and start counting lines numbers afresh) + let patch_str = indoc! {" + diff --git a/text.txt b/text.txt + index 86c770d..a1fd855 100644 + --- a/text.txt + +++ b/text.txt + @@ -1,7 +1,7 @@ + azuere + beige + black + -blue + +dark blue + brown + cyan + gold + --- a/text.txt + +++ b/text.txt + @@ -16,4 +16,3 @@ red + violet + white + yellow + -zinc + "}; + + // When removing edit from one of the files (`+dark blue`) + let mut patch = Patch::parse_unified_diff(patch_str); + remove_edits(&mut patch, vec![1]); + + // Then the line numbers should **not** be adjusted in a subsequent hunk, + // because it starts with a file header + assert_eq!(patch.hunks[0].header_string(), "@@ -1,7 +1,6 @@"); // edited hunk + assert_eq!(patch.hunks[1].header_string(), "@@ -16,4 +16,3 @@ red"); // subsequent hunk + } + } + + mod apply_edits { + use super::*; + use indoc::indoc; + use pretty_assertions::assert_eq; + + static PATCH: &'static str = indoc! {" + diff --git a/text.txt b/text.txt + index 86c770d..a1fd855 100644 + --- a/text.txt + +++ b/text.txt + @@ -1,7 +1,7 @@ + azuere + beige + black + -blue + +dark blue + brown + cyan + gold + --- a/text.txt + +++ b/text.txt + @@ -9,6 +9,7 @@ gray + green + indigo + magenta + +silver + orange + pink + purple + --- a/text.txt + +++ b/text.txt + @@ -16,4 +17,3 @@ red + violet + white + yellow + -zinc + "}; + + #[test] + fn test_removes_hunks_without_edits() { + // When applying the first two edits (`-blue`, `+dark blue`) + let mut patch = Patch::parse_unified_diff(PATCH); + apply_edits(&mut patch, vec![0, 1]); + + // Then the whole hunk should be removed since there are no other edits in it, + // and the line numbers should be adjusted in the subsequent hunks + assert_eq!(patch.hunks[0].header_string(), "@@ -9,6 +9,7 @@ gray"); + assert_eq!(patch.hunks[1].header_string(), "@@ -16,4 +17,3 @@ red"); + assert_eq!(patch.hunks.len(), 2); + } + + #[test] + fn test_adjust_line_numbers_after_applying_deletion() { + // Apply the first deletion (`-blue`) + let mut patch = Patch::parse_unified_diff(PATCH); + apply_edits(&mut patch, vec![0]); + + // The line numbers should be adjusted + assert_eq!(patch.hunks[0].header_string(), "@@ -1,6 +1,7 @@"); + assert_eq!(patch.hunks[1].header_string(), "@@ -8,6 +9,7 @@ gray"); + assert_eq!(patch.hunks[2].header_string(), "@@ -15,4 +17,3 @@ red"); + } + #[test] + fn test_adjust_line_numbers_after_applying_insertion() { + // Apply the first insertion (`+dark blue`) + let mut patch = Patch::parse_unified_diff(PATCH); + apply_edits(&mut patch, vec![1]); + + // The line numbers should be adjusted in the subsequent hunks + println!("{}", &patch.to_string()); + assert_eq!(patch.hunks[0].header_string(), "@@ -1,7 +1,6 @@"); + assert_eq!(patch.hunks[1].header_string(), "@@ -10,6 +9,7 @@ gray"); + assert_eq!(patch.hunks[2].header_string(), "@@ -17,4 +17,3 @@ red"); + } + } + + mod reorder_edits { + use super::*; + use indoc::indoc; + use pretty_assertions::assert_eq; + + static PATCH: &'static str = indoc! {" + Some header. 
+ + diff --git a/first.txt b/first.txt + index 86c770d..a1fd855 100644 + --- a/first.txt + +++ b/first.txt + @@ -1,7 +1,7 @@ + azuere + beige + black + -blue + +dark blue + brown + cyan + gold + --- a/second.txt + +++ b/second.txt + @@ -9,6 +9,7 @@ gray + green + indigo + magenta + +silver + orange + pink + purple + --- a/first.txt + +++ b/first.txt + @@ -16,4 +17,3 @@ red + violet + white + yellow + -zinc + "}; + + #[test] + fn test_reorder_1() { + let edits_order = vec![ + BTreeSet::from([2]), // +silver + BTreeSet::from([3]), // -zinc + BTreeSet::from([0, 1]), // -blue +dark blue + ]; + + let patch = Patch::parse_unified_diff(PATCH); + let reordered_patch = reorder_edits(&patch, edits_order); + + // The whole hunk should be removed since there are no other edits in it + let actual = reordered_patch.to_string(); + + println!("{}", actual); + + let expected = indoc! {" + Some header. + + --- a/second.txt + +++ b/second.txt + @@ -9,6 +9,7 @@ gray + green + indigo + magenta + +silver + orange + pink + purple + --- a/first.txt + +++ b/first.txt + @@ -16,4 +17,3 @@ red + violet + white + yellow + -zinc + --- a/first.txt + +++ b/first.txt + @@ -1,7 +1,7 @@ + azuere + beige + black + -blue + +dark blue + brown + cyan + gold + "}; + assert_eq!(actual, String::from(expected)); + } + + #[test] + fn test_reorder_duplicates() { + let edits_order = vec![ + BTreeSet::from([2]), // +silver + BTreeSet::from([2]), // +silver again + BTreeSet::from([3]), // -zinc + ]; + + let patch = Patch::parse_unified_diff(PATCH); + let reordered_patch = reorder_edits(&patch, edits_order); + + // The whole hunk should be removed since there are no other edits in it + let actual = reordered_patch.to_string(); + + println!("{}", actual); + + let expected = indoc! {" + Some header. + + --- a/second.txt + +++ b/second.txt + @@ -9,6 +9,7 @@ gray + green + indigo + magenta + +silver + orange + pink + purple + --- a/first.txt + +++ b/first.txt + @@ -16,4 +17,3 @@ red + violet + white + yellow + -zinc + "}; + assert_eq!(actual, String::from(expected)); + } + } + + mod extract_edits { + + use super::*; + use indoc::indoc; + use pretty_assertions::assert_eq; + + static PATCH: &'static str = indoc! {" + Some header. + + diff --git a/first.txt b/first.txt + index 86c770d..a1fd855 100644 + --- a/first.txt + +++ b/first.txt + @@ -1,7 +1,7 @@ + azuere + beige + black + -blue + +dark blue + brown + cyan + gold + @@ -16,4 +17,3 @@ red + violet + white + yellow + -zinc + --- a/second.txt + +++ b/second.txt + @@ -9,6 +9,7 @@ gray + green + indigo + magenta + +silver + orange + pink + purple + "}; + + #[test] + fn test_extract_edits() { + let to_extract = BTreeSet::from([ + 3, // +silver + 0, // -blue + ]); + + let mut patch = Patch::parse_unified_diff(PATCH); + let (extracted, remainder) = extract_edits(&mut patch, &to_extract); + + // Edits will be extracted in the sorted order, so [0, 3] + let expected_extracted = indoc! {" + Some header. + + --- a/first.txt + +++ b/first.txt + @@ -1,7 +1,6 @@ + azuere + beige + black + -blue + brown + cyan + gold + --- a/second.txt + +++ b/second.txt + @@ -9,6 +9,7 @@ gray + green + indigo + magenta + +silver + orange + pink + purple + "}; + + let expected_remainder = indoc! {" + Some header. 
+ + --- a/first.txt + +++ b/first.txt + @@ -1,6 +1,7 @@ + azuere + beige + black + +dark blue + brown + cyan + gold + --- a/first.txt + +++ b/first.txt + @@ -15,4 +17,3 @@ red + violet + white + yellow + -zinc + "}; + assert_eq!(extracted.to_string(), String::from(expected_extracted)); + assert_eq!(remainder.to_string(), String::from(expected_remainder)); + } + } + + #[test] + fn test_parse_order_file() { + let content = r#" +// Add new dependency +1, 49 + +// Add new imports and types +8-9, 51 + +// Add new struct and login command method +10-47 + +// Modify AgentServerDelegate to make status_tx optional +2-3 + +// Update status_tx usage to handle optional value +4 +5-7 + +// Update all existing callers to use None for status_tx +48, 50 + +// Update the main login implementation to use custom command +52-55 +56-95 +"#; + + let order = parse_order_spec(content); + + assert_eq!(order.len(), 9); + + // First group: 1, 49 + assert_eq!(order[0], BTreeSet::from([1, 49])); + + // Second group: 8-9, 51 + assert_eq!(order[1], BTreeSet::from([8, 9, 51])); + + // Third group: 10-47 + let expected_range: BTreeSet = (10..=47).collect(); + assert_eq!(order[2], expected_range); + + // Fourth group: 2-3 + assert_eq!(order[3], BTreeSet::from([2, 3])); + + // Fifth group: 4 + assert_eq!(order[4], BTreeSet::from([4])); + + // Sixth group: 5-7 + assert_eq!(order[5], BTreeSet::from([5, 6, 7])); + + // Seventh group: 48, 50 + assert_eq!(order[6], BTreeSet::from([48, 50])); + + // Eighth group: 52-55 + assert_eq!(order[7], BTreeSet::from([52, 53, 54, 55])); + + // Ninth group: 56-95 + let expected_range_2: BTreeSet = (56..=95).collect(); + assert_eq!(order[8], expected_range_2); + } + + #[test] + fn test_normalize_hunk() { + let mut patch = Patch::parse_unified_diff(indoc! {" + This patch has too many lines of context. + + --- a/first.txt + +++ b/first.txt + @@ -1,7 +1,6 @@ + azuere + beige + black + -blue + brown + cyan + gold + // Some garbage + "}); + + patch.normalize_hunks(1); + let actual = patch.to_string(); + assert_eq!( + actual, + indoc! {" + This patch has too many lines of context. + + --- a/first.txt + +++ b/first.txt + @@ -3,3 +3,2 @@ + black + -blue + brown + // Some garbage + "} + ); + } +} diff --git a/crates/edit_prediction_cli/src/split_commit.rs b/crates/edit_prediction_cli/src/split_commit.rs new file mode 100644 index 0000000000000000000000000000000000000000..88be74511901a7704fdaa3934a01ad20abe3b032 --- /dev/null +++ b/crates/edit_prediction_cli/src/split_commit.rs @@ -0,0 +1,1465 @@ +//! `ep split-commit` implementation. +//! +//! This command generates a single evaluation example JSON object from a +//! chronologically-ordered unified diff (a "commit"). +//! +//! TODO: Port Python code to generate chronologically-ordered commits +use crate::reorder_patch::{Patch, PatchLine, extract_edits, locate_edited_line}; +use anyhow::{Context as _, Result}; +use clap::Args; +use rand::Rng; +use rand::SeedableRng; +use serde::{Deserialize, Serialize}; +use similar::{DiffTag, TextDiff}; +use std::collections::BTreeSet; +use std::fs; +use std::io::{self, Read}; + +/// `ep split-commit` CLI args. 
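+///
+/// An illustrative invocation (the file name and values are examples):
+///
+/// ```text
+/// ep split-commit --commit ordered.diff --split-point 0.5 --seed 42 --pretty
+/// ```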
+#[derive(Debug, Args)]
+pub struct SplitCommitArgs {
+    /// Path to the commit file (use "-" for stdin)
+    #[arg(long, short = 'c')]
+    pub commit: String,
+
+    /// Repository URL
+    #[arg(long, short = 'r', default_value_t = String::new())]
+    pub repository_url: String,
+
+    /// Commit hash
+    #[arg(long, default_value_t = String::new())]
+    pub commit_hash: String,
+
+    /// Split point (float 0.0-1.0 for fraction, or integer for index)
+    #[arg(long, short = 's')]
+    pub split_point: Option<String>,
+
+    /// Random seed for reproducibility
+    #[arg(long)]
+    pub seed: Option<u64>,
+
+    /// Pretty-print JSON output
+    #[arg(long, short = 'p')]
+    pub pretty: bool,
+}
+
+/// Cursor position in a file.
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct CursorPosition {
+    pub file: String,
+    pub line: usize,
+    pub column: usize,
+}
+
+impl std::fmt::Display for CursorPosition {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}:{}:{}", self.file, self.line, self.column)
+    }
+}
+
+/// Represents a split commit with source and target patches.
+#[derive(Debug, Clone)]
+pub struct SplitCommit {
+    pub source_patch: String,
+    pub target_patch: String,
+}
+
+/// The evaluation case structure that will be serialized to JSON.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct EvaluationCase {
+    pub repository_url: String,
+    pub commit: String,
+    pub edit_history: Vec<String>,
+    pub cursor_position: String,
+    pub cursor_excerpt: String,
+    pub expected_hunks: Vec<String>,
+    pub expected_patch: String,
+    pub allowed_patch: String,
+    pub expected_context_excerpts: Vec<String>,
+    pub extra: serde_json::Value,
+}
+
+/// Split point specification for evaluation generation.
+#[derive(Debug, Clone)]
+pub enum SplitPoint {
+    /// Fraction of total edits (0.0 to 1.0)
+    Fraction(f64),
+    /// Absolute index
+    Index(usize),
+}
+
+fn parse_split_point(value: &str) -> Option<SplitPoint> {
+    if value.contains('.') {
+        value.parse::<f64>().ok().map(SplitPoint::Fraction)
+    } else {
+        value.parse::<usize>().ok().map(SplitPoint::Index)
+    }
+}
+
+/// Entry point for the `ep split-commit` subcommand.
+///
+/// This runs synchronously and prints a single JSON object to stdout.
+pub fn run_split_commit(args: &SplitCommitArgs) -> Result<()> {
+    let commit = if args.commit == "-" {
+        let mut content = String::new();
+        io::stdin()
+            .read_to_string(&mut content)
+            .context("failed to read commit diff from stdin")?;
+        content
+    } else {
+        fs::read_to_string(&args.commit)
+            .with_context(|| format!("failed to read commit diff from {}", args.commit))?
+    };
+
+    let split_point = args.split_point.as_deref().and_then(parse_split_point);
+
+    let case = generate_evaluation_example_from_ordered_commit(
+        &commit,
+        &args.repository_url,
+        &args.commit_hash,
+        split_point,
+        args.seed,
+    )
+    .context("failed to generate evaluation example")?;
+
+    let json = if args.pretty {
+        serde_json::to_string_pretty(&case)
+    } else {
+        serde_json::to_string(&case)
+    }
+    .context("failed to serialize evaluation case as JSON")?;
+
+    println!("{json}");
+    Ok(())
+}
+
+/// Main function to generate an evaluation example from an ordered commit.
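+///
+/// A sketch of a direct call (inputs are illustrative):
+///
+/// ```rust,ignore
+/// let case = generate_evaluation_example_from_ordered_commit(
+///     diff_text, // chronologically ordered unified diff
+///     "https://example.com/repo.git",
+///     "abc123",
+///     Some(SplitPoint::Fraction(0.5)), // split halfway through the edits
+///     Some(42),                        // fixed seed for reproducibility
+/// )?;
+/// println!("{}", case.cursor_position);
+/// ```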
+///
+/// # Arguments
+/// * `commit` - Chronologically ordered unified diff of the commit
+/// * `repository_url` - URL of the repository
+/// * `commit_hash` - Hash of the commit
+/// * `split_point` - Point at which the commit will be split (None for random)
+/// * `seed` - Optional seed for randomness
+pub fn generate_evaluation_example_from_ordered_commit(
+    commit: &str,
+    repository_url: &str,
+    commit_hash: &str,
+    split_point: Option<SplitPoint>,
+    seed: Option<u64>,
+) -> Result<EvaluationCase> {
+    let mut rng: Box<dyn rand::RngCore> = match seed {
+        Some(seed) => Box::new(rand::rngs::StdRng::seed_from_u64(seed)),
+        None => Box::new(rand::rngs::ThreadRng::default()),
+    };
+
+    // Parse and normalize the commit
+    let mut patch = Patch::parse_unified_diff(commit);
+
+    // Filter header to only keep lines starting with "//"
+    let header_lines: Vec<&str> = patch
+        .header
+        .lines()
+        .filter(|line| line.starts_with("//"))
+        .collect();
+    patch.header = if header_lines.is_empty() {
+        String::new()
+    } else {
+        header_lines.join("\n") + "\n"
+    };
+    let commit_normalized = patch.to_string();
+
+    // Compute the split point
+    let stats = patch.stats();
+    let num_edits = stats.added + stats.removed;
+
+    anyhow::ensure!(num_edits != 0, "no edits found in commit");
+
+    let split = match split_point {
+        None => rng.random_range(1..=num_edits),
+        Some(SplitPoint::Fraction(f)) => {
+            let v = (f * num_edits as f64).floor() as usize;
+            v.min(num_edits)
+        }
+        Some(SplitPoint::Index(i)) => i.min(num_edits),
+    };
+
+    // Split the commit into source and target patches
+    let (prefix, suffix) = split_ordered_commit(&commit_normalized, split);
+
+    let mut split_commit = SplitCommit {
+        source_patch: prefix,
+        target_patch: suffix,
+    };
+
+    // Imitate human edits
+    let human_edit_seed = rng.random_range(1..=10000u64);
+    let (src_patch, tgt_patch, cursor_opt) = imitate_human_edits(
+        &split_commit.source_patch,
+        &split_commit.target_patch,
+        human_edit_seed,
+    );
+    split_commit.source_patch = src_patch;
+    split_commit.target_patch = tgt_patch;
+
+    // Sample cursor position
+    let cursor = match cursor_opt {
+        Some(c) => c,
+        None => sample_cursor_position(&patch, &split_commit)
+            .context("failed to sample cursor position")?,
+    };
+
+    // Get cursor excerpt
+    let cursor_excerpt = get_cursor_excerpt(
+        &cursor,
+        &split_commit.source_patch,
+        &split_commit.target_patch,
+    )
+    .context("failed to generate cursor excerpt")?;
+
+    // Handle edge case where split_point == 0
+    if split == 0 {
+        split_commit.target_patch = String::new();
+    }
+
+    Ok(EvaluationCase {
+        repository_url: repository_url.to_string(),
+        commit: format!("{}~1", commit_hash),
+        edit_history: vec![split_commit.source_patch.clone()],
+        cursor_position: cursor.to_string(),
+        cursor_excerpt,
+        expected_hunks: vec![split_commit.target_patch.clone()],
+        expected_patch: split_commit.target_patch.clone(),
+        allowed_patch: split_commit.target_patch,
+        expected_context_excerpts: vec![],
+        extra: serde_json::json!({}),
+    })
+}
+
+/// Split an ordered commit into source and target commits.
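+///
+/// For instance, with a commit containing four edited lines,
+/// `split_ordered_commit(commit, 2)` puts edits 0 and 1 into the source diff
+/// and leaves edits 2 and 3 in the target diff.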
+///
+/// # Arguments
+/// * `commit` - Ordered commit string
+/// * `split_pos` - Position to split the commit (number of edited lines)
+///
+/// # Returns
+/// A tuple of (source_diff, target_diff)
+pub fn split_ordered_commit(commit: &str, split_pos: usize) -> (String, String) {
+    let patch = Patch::parse_unified_diff(commit);
+    let source_edits: BTreeSet<usize> = (0..split_pos).collect();
+    let (source, target) = extract_edits(&patch, &source_edits);
+
+    let mut source_str = source.to_string();
+    let target_str = target.to_string();
+
+    // Strip last group header from the source (lines starting with "//" at the end)
+    let source_lines: Vec<&str> = source_str.lines().collect();
+    let mut end_idx = source_lines.len();
+    for i in (0..source_lines.len()).rev() {
+        if source_lines[i].starts_with("//") {
+            end_idx = i;
+        } else {
+            break;
+        }
+    }
+    if end_idx < source_lines.len() {
+        source_str = source_lines[..end_idx].join("\n");
+        if !source_str.is_empty() {
+            source_str.push('\n');
+        }
+    }
+
+    (source_str, target_str)
+}
+
+/// Tokenize text into words and non-word characters.
+fn tokenize(text: &str) -> Vec<String> {
+    let mut tokens = Vec::new();
+    let mut current = String::new();
+
+    for ch in text.chars() {
+        if ch.is_alphanumeric() {
+            current.push(ch);
+        } else if ch == '_' {
+            // Include underscore with the current word, then flush
+            current.push(ch);
+            if !current.is_empty() {
+                tokens.push(std::mem::take(&mut current));
+            }
+        } else {
+            // Punctuation or whitespace - flush current word first
+            if !current.is_empty() {
+                tokens.push(std::mem::take(&mut current));
+            }
+            // Each punctuation/whitespace is its own token
+            tokens.push(ch.to_string());
+        }
+    }
+
+    if !current.is_empty() {
+        tokens.push(current);
+    }
+
+    tokens
+}
+
+/// Calculate the weight for a split position based on the character at that position.
+///
+/// Higher weights indicate more natural pause points (e.g., after punctuation,
+/// at identifier boundaries). Lower weights indicate less natural points
+/// (e.g., mid-identifier).
+fn position_weight(text: &str, pos: usize) -> u32 {
+    if pos == 0 || pos > text.len() {
+        return 1;
+    }
+
+    let chars: Vec<char> = text.chars().collect();
+    if pos > chars.len() {
+        return 1;
+    }
+
+    // Get the character just before this position (what we just "typed")
+    let prev_char = chars[pos - 1];
+
+    // High weight: natural pause points (end of statement/argument, opening brackets)
+    if matches!(prev_char, ',' | ';' | ':' | '(' | '[' | '{') {
+        return 10;
+    }
+
+    // High weight: closing brackets (finished a group)
+    if matches!(prev_char, ')' | ']' | '}') {
+        return 8;
+    }
+
+    // Medium weight: operators and method chains
+    if matches!(
+        prev_char,
+        '.' | '+' | '-' | '*' | '/' | '=' | '<' | '>' | '&' | '|' | '!'
+    ) {
+        return 5;
+    }
+
+    // Check if we're at the end of an identifier (word char followed by non-word char)
+    let is_prev_word_char = prev_char.is_alphanumeric() || prev_char == '_';
+    let is_next_word_char =
+        pos < chars.len() && (chars[pos].is_alphanumeric() || chars[pos] == '_');
+
+    if is_prev_word_char && !is_next_word_char {
+        // End of identifier - high weight
+        return 8;
+    }
+
+    // Whitespace is a natural pause
+    if prev_char.is_whitespace() {
+        return 6;
+    }
+
+    // Mid-identifier: low weight (rare autocomplete scenarios)
+    if is_prev_word_char && is_next_word_char {
+        return 1;
+    }
+
+    // Default medium-low weight
+    3
+}
+
+/// Select a weighted random index from a list of weights.
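+/// (For example, weights `[1, 10, 1]` make index 1 the result for 10 of every
+/// 12 seed values.)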
+/// +/// Returns an index based on the weights, using the provided seed for +/// deterministic selection. +fn weighted_select(weights: &[u32], seed: u64) -> usize { + if weights.is_empty() { + return 0; + } + + let total_weight: u64 = weights.iter().map(|&w| w as u64).sum(); + if total_weight == 0 { + // Fallback to uniform selection if all weights are zero + return seed as usize % weights.len(); + } + + // Use seed to select a value in [0, total_weight) + let target = seed % total_weight; + let mut cumulative: u64 = 0; + + for (idx, &weight) in weights.iter().enumerate() { + cumulative += weight as u64; + if target < cumulative { + return idx; + } + } + + // Fallback to last index + weights.len() - 1 +} + +/// Calculate similarity ratio between two strings (0-100). +fn fuzzy_ratio(s1: &str, s2: &str) -> u32 { + if s1.is_empty() && s2.is_empty() { + return 100; + } + if s1.is_empty() || s2.is_empty() { + return 0; + } + + let diff = TextDiff::from_chars(s1, s2); + let matching: usize = diff + .ops() + .iter() + .filter_map(|op| { + if matches!(op.tag(), DiffTag::Equal) { + Some(op.new_range().len()) + } else { + None + } + }) + .sum(); + + let total = s1.len() + s2.len(); + ((2 * matching * 100) / total) as u32 +} + +/// Imitate human edits by introducing partial line edits. +/// +/// This function simulates how a human might incrementally type code, +/// rather than making complete line replacements. +pub fn imitate_human_edits( + source_patch: &str, + target_patch: &str, + seed: u64, +) -> (String, String, Option) { + let no_change = (source_patch.to_string(), target_patch.to_string(), None); + + let src_patch = Patch::parse_unified_diff(source_patch); + let tgt_patch = Patch::parse_unified_diff(target_patch); + + if tgt_patch.hunks.is_empty() { + return no_change; + } + + // Try to locate the first edit in target + let tgt_edit_loc = match locate_edited_line(&tgt_patch, 0) { + Some(loc) => loc, + None => return no_change, + }; + + let tgt_is_addition = matches!(tgt_edit_loc.patch_line, PatchLine::Addition(_)); + if !tgt_is_addition { + return no_change; + } + + let tgt_line = match &tgt_edit_loc.patch_line { + PatchLine::Addition(s) => s.clone(), + _ => return no_change, + }; + + // Try to locate the last edit in source + let src_edit_loc = locate_edited_line(&src_patch, -1); + + // Check if source has ANY edit at the same line as target's first edit + // We need to iterate through all edits to check this + let src_has_edit_at_target_line = { + let mut found = false; + let mut idx = 0isize; + while let Some(loc) = locate_edited_line(&src_patch, idx) { + if loc.filename == tgt_edit_loc.filename + && loc.target_line_number == tgt_edit_loc.target_line_number + { + found = true; + break; + } + idx += 1; + } + found + }; + + // Check if this is a replacement (deletion followed by insertion on the same line) + // or a pure insertion (no corresponding deletion in source) + let is_replacement = src_edit_loc.as_ref().map_or(false, |loc| { + matches!(loc.patch_line, PatchLine::Deletion(_)) + && loc.filename == tgt_edit_loc.filename + && loc.target_line_number == tgt_edit_loc.target_line_number + }); + + // If source has an edit at the same line but it's not a replacement (i.e., it's an addition), + // we shouldn't process this as a pure insertion either + if !is_replacement && src_has_edit_at_target_line { + return no_change; + } + + let src_line = if is_replacement { + match &src_edit_loc.as_ref().unwrap().patch_line { + PatchLine::Deletion(s) => s.clone(), + _ => return no_change, + } + } else 
{ + // Pure insertion: source line is empty + String::new() + }; + + // Don't process if source and target are the same + if src_line == tgt_line { + return no_change; + } + + // Tokenize both lines + let src_tokens = tokenize(&src_line); + let tgt_tokens = tokenize(&tgt_line); + + // Convert to slices for similar + let src_refs: Vec<&str> = src_tokens.iter().map(|s| s.as_str()).collect(); + let tgt_refs: Vec<&str> = tgt_tokens.iter().map(|s| s.as_str()).collect(); + + // Use similar to get diff operations + let diff = TextDiff::from_slices(&src_refs, &tgt_refs); + + // Build weights for each possible split position + let mut position_weights: Vec = Vec::new(); + + // Simulate the edit process to collect weights for all possible split positions + { + let mut current_text = String::new(); + + for op in diff.ops() { + match op.tag() { + DiffTag::Equal => { + for i in op.old_range() { + current_text.push_str(&src_tokens[i]); + } + } + DiffTag::Replace => { + let ins: String = op.new_range().map(|i| tgt_tokens[i].as_str()).collect(); + let del: String = op.old_range().map(|i| src_tokens[i].as_str()).collect(); + + // For insertion part + for ch in ins.chars() { + current_text.push(ch); + let weight = position_weight(¤t_text, current_text.len()); + position_weights.push(weight); + } + + // For deletion part (we're "untyping" from source) + for _ in del.chars() { + // Weight deletions lower as they represent removing text + position_weights.push(2); + } + } + DiffTag::Insert => { + let ins: String = op.new_range().map(|i| tgt_tokens[i].as_str()).collect(); + for ch in ins.chars() { + current_text.push(ch); + let weight = position_weight(¤t_text, current_text.len()); + position_weights.push(weight); + } + } + DiffTag::Delete => { + let del: String = op.old_range().map(|i| src_tokens[i].as_str()).collect(); + for _ in del.chars() { + // Weight deletions lower + position_weights.push(2); + } + } + } + } + } + + // Use weighted selection to choose split index + if position_weights.is_empty() { + return no_change; + } + let split_index = weighted_select(&position_weights, seed); + + let mut edit_index = 0usize; + let mut new_src = String::new(); + let mut split_found = false; + let mut last_old_end = 0usize; + + for op in diff.ops() { + match op.tag() { + DiffTag::Equal => { + for i in op.old_range() { + new_src.push_str(&src_tokens[i]); + } + last_old_end = op.old_range().end; + } + DiffTag::Replace => { + // Handle replace as delete + insert + let del: String = op.old_range().map(|i| src_tokens[i].as_str()).collect(); + let ins: String = op.new_range().map(|i| tgt_tokens[i].as_str()).collect(); + let repl_len = del.len() + ins.len(); + if edit_index + repl_len >= split_index { + // Split within this replace operation + let offset = split_index - edit_index; + if offset < ins.len() { + new_src.push_str(&ins[..offset]); + } else { + new_src.push_str(&ins); + let del_offset = offset - ins.len(); + new_src.push_str(&del[..del_offset.min(del.len())]); + } + split_found = true; + last_old_end = op.old_range().end; + break; + } else { + edit_index += repl_len; + new_src.push_str(&ins); + last_old_end = op.old_range().end; + } + } + DiffTag::Insert => { + let repl: String = op.new_range().map(|i| tgt_tokens[i].as_str()).collect(); + if edit_index + repl.len() >= split_index { + let offset = split_index - edit_index; + new_src.push_str(&repl[..offset]); + split_found = true; + break; + } else { + edit_index += repl.len(); + new_src.push_str(&repl); + } + } + DiffTag::Delete => { + let repl: String = 
op.old_range().map(|i| src_tokens[i].as_str()).collect(); + if edit_index + repl.len() >= split_index { + let offset = split_index - edit_index; + new_src.push_str(&repl[..offset]); + split_found = true; + last_old_end = op.old_range().start + offset.min(op.old_range().len()); + break; + } else { + edit_index += repl.len(); + new_src.push_str(&repl); + last_old_end = op.old_range().end; + } + } + } + } + + if !split_found { + return no_change; + } + + // Calculate cursor position + let cursor = CursorPosition { + file: tgt_edit_loc.filename.clone(), + line: if is_replacement { + src_edit_loc.as_ref().unwrap().source_line_number + } else { + tgt_edit_loc.target_line_number + }, + column: new_src.len() + 1, + }; + + // Add remainder of source if similar enough to target remainder + let remainder_src: String = (last_old_end..src_tokens.len()) + .map(|i| src_tokens[i].as_str()) + .collect(); + let remainder_tgt: String = (last_old_end..tgt_tokens.len()) + .filter_map(|i| tgt_tokens.get(i).map(|s| s.as_str())) + .collect(); + + let ratio = fuzzy_ratio(&remainder_src, &remainder_tgt); + if ratio > 35 { + new_src.push_str(&remainder_src); + } + + if new_src.trim().is_empty() { + return no_change; + } + + if new_src == src_line { + return no_change; + } + + // Build new source patch with the intermediate line + let mut new_src_patch = src_patch; + if is_replacement { + // For replacements, insert after the deletion line + let src_loc = src_edit_loc.as_ref().unwrap(); + if let Some(hunk) = new_src_patch.hunks.get_mut(src_loc.hunk_index) { + hunk.lines.insert( + src_loc.line_index_within_hunk + 1, + PatchLine::Addition(new_src.clone()), + ); + hunk.new_count += 1; + } + } else { + // For pure insertions, we need to add or modify a hunk + if let Some(hunk) = new_src_patch.hunks.get_mut(tgt_edit_loc.hunk_index) { + // Insert the partial line at the same position as target + hunk.lines.insert( + tgt_edit_loc.line_index_within_hunk, + PatchLine::Addition(new_src.clone()), + ); + hunk.new_count += 1; + } else if new_src_patch.hunks.is_empty() { + // Source patch is empty, create a new hunk based on target + if let Some(tgt_hunk) = tgt_patch.hunks.get(tgt_edit_loc.hunk_index) { + let mut new_hunk = tgt_hunk.clone(); + // Replace the full addition with the partial one + new_hunk.lines.clear(); + for (i, line) in tgt_hunk.lines.iter().enumerate() { + if i == tgt_edit_loc.line_index_within_hunk { + new_hunk.lines.push(PatchLine::Addition(new_src.clone())); + } else { + match line { + PatchLine::Addition(_) => { + // Skip other additions from target + } + _ => new_hunk.lines.push(line.clone()), + } + } + } + new_hunk.new_count = new_hunk.old_count + 1; + new_src_patch.hunks.push(new_hunk); + // Copy header from target if source doesn't have one + if new_src_patch.header.is_empty() { + new_src_patch.header = tgt_patch.header.clone(); + } + } + } + } + + // Build new target patch with the intermediate line as deletion + let mut new_tgt_patch = tgt_patch; + if let Some(hunk) = new_tgt_patch.hunks.get_mut(tgt_edit_loc.hunk_index) { + hunk.lines.insert( + tgt_edit_loc.line_index_within_hunk, + PatchLine::Deletion(new_src), + ); + hunk.old_count += 1; + } + + ( + new_src_patch.to_string(), + new_tgt_patch.to_string(), + Some(cursor), + ) +} + +/// Locate the end of the last edit in a patch. 
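+///
+/// For an addition this is the end of the added line; for a deletion it is
+/// column 1 of the corresponding line in the target file.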
+/// Locate the end of the last edit in a patch.
+fn locate_end_of_last_edit(patch: &Patch) -> Option<CursorPosition> {
+    let loc = locate_edited_line(patch, -1)?;
+
+    let (line, col) = match &loc.patch_line {
+        PatchLine::Addition(content) => (loc.target_line_number, content.len()),
+        PatchLine::Deletion(_) => (loc.target_line_number, 1),
+        _ => return None,
+    };
+
+    Some(CursorPosition {
+        file: loc.filename,
+        line,
+        column: col,
+    })
+}
+
+/// Locate the beginning of the first edit in a patch.
+fn locate_beginning_of_first_edit(patch: &Patch) -> Option<CursorPosition> {
+    let loc = locate_edited_line(patch, 0)?;
+
+    let hunk = patch.hunks.get(loc.hunk_index)?;
+    let column = if loc.line_index_within_hunk > 0 {
+        if let Some(prev_line) = hunk.lines.get(loc.line_index_within_hunk - 1) {
+            let content = match prev_line {
+                PatchLine::Context(s) | PatchLine::Addition(s) | PatchLine::Deletion(s) => s,
+                _ => return None,
+            };
+            content.len().max(1) - 1
+        } else {
+            0
+        }
+    } else {
+        0
+    };
+
+    let line = loc.target_line_number.saturating_sub(1).max(1);
+
+    Some(CursorPosition {
+        file: loc.filename,
+        line,
+        column,
+    })
+}
+
+/// Pick a cursor position for the split. Despite the name, the choice is
+/// deterministic, in order of preference:
+/// 1. The end of the last edit in the source (already-typed) patch
+/// 2. The beginning of the first edit in the target patch
+/// 3. The end of the last edit in the original, unsplit patch
+pub fn sample_cursor_position(patch: &Patch, split_commit: &SplitCommit) -> Option<CursorPosition> {
+    // Try end of history first
+    let src_patch = Patch::parse_unified_diff(&split_commit.source_patch);
+    if let Some(cursor) = locate_end_of_last_edit(&src_patch) {
+        return Some(cursor);
+    }
+
+    // Try beginning of target
+    let tgt_patch = Patch::parse_unified_diff(&split_commit.target_patch);
+    if let Some(cursor) = locate_beginning_of_first_edit(&tgt_patch) {
+        return Some(cursor);
+    }
+
+    // Fallback: use the original patch
+    locate_end_of_last_edit(patch)
+}
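Likewise, `tokenize` and `fuzzy_ratio`, used by the intermediate-state construction above, are defined outside this excerpt. Here is a sketch consistent with `test_tokenize` and `test_fuzzy_ratio` below, assuming ASCII input and a plain Levenshtein-based similarity; the real code may differ.

```rust
// Sketch only: reconstructed from the tests; not the code this diff adds.

/// Split a line into small tokens: runs of alphanumerics (an underscore ends
/// a run and stays attached to it), with every other character on its own.
fn tokenize(s: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut current = String::new();
    for ch in s.chars() {
        if ch.is_alphanumeric() {
            current.push(ch);
        } else if ch == '_' {
            current.push(ch); // "foo_bar" tokenizes as ["foo_", "bar"]
            tokens.push(std::mem::take(&mut current));
        } else {
            if !current.is_empty() {
                tokens.push(std::mem::take(&mut current));
            }
            tokens.push(ch.to_string()); // punctuation and whitespace are single tokens
        }
    }
    if !current.is_empty() {
        tokens.push(current);
    }
    tokens
}

/// Similarity of two strings as a 0-100 percentage (100 = identical),
/// derived from the Levenshtein edit distance.
fn fuzzy_ratio(a: &str, b: &str) -> usize {
    if a.is_empty() && b.is_empty() {
        return 100;
    }
    let a: Vec<char> = a.chars().collect();
    let b: Vec<char> = b.chars().collect();
    // Single-row dynamic-programming Levenshtein distance.
    let mut prev: Vec<usize> = (0..=b.len()).collect();
    for (i, ca) in a.iter().enumerate() {
        let mut row = vec![i + 1];
        for (j, cb) in b.iter().enumerate() {
            let cost = if ca == cb { 0 } else { 1 };
            row.push((prev[j] + cost).min(prev[j + 1] + 1).min(row[j] + 1));
        }
        prev = row;
    }
    let distance = prev[b.len()];
    100 - (100 * distance) / a.len().max(b.len())
}
```

The `ratio > 35` gate in the split logic above uses this score to decide whether the untouched tail of the source line is close enough to the target's tail to carry over into the intermediate line.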
+/// Get cursor excerpt from the patches.
+///
+/// This extracts the lines around the cursor position with a cursor marker.
+pub fn get_cursor_excerpt(
+    cursor: &CursorPosition,
+    source_patch: &str,
+    target_patch: &str,
+) -> Option<String> {
+    let mut excerpt_lines: Vec<String> = Vec::new();
+    let mut excerpt_first_line: usize = 0;
+
+    // Search in the last hunk of source patch
+    let src = Patch::parse_unified_diff(source_patch);
+    if let Some(loc) = locate_edited_line(&src, -1) {
+        if loc.filename == cursor.file && loc.target_line_number == cursor.line {
+            if let Some(hunk) = src.hunks.get(loc.hunk_index) {
+                excerpt_first_line = hunk.new_start as usize;
+                for line in &hunk.lines {
+                    match line {
+                        PatchLine::Addition(s) | PatchLine::Context(s) => {
+                            excerpt_lines.push(s.clone());
+                        }
+                        _ => {}
+                    }
+                }
+            }
+        }
+    }
+
+    // Search in target patch if not found
+    if excerpt_lines.is_empty() {
+        let tgt = Patch::parse_unified_diff(target_patch);
+        if let Some(loc) = locate_edited_line(&tgt, 0) {
+            if loc.filename == cursor.file {
+                if let Some(hunk) = tgt.hunks.get(loc.hunk_index) {
+                    excerpt_first_line = hunk.new_start as usize;
+                    for line in &hunk.lines {
+                        match line {
+                            PatchLine::Deletion(s) | PatchLine::Context(s) => {
+                                excerpt_lines.push(s.clone());
+                            }
+                            _ => {}
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    if excerpt_lines.is_empty() {
+        return None;
+    }
+
+    // Add cursor marker
+    for (i, line) in excerpt_lines.iter_mut().enumerate() {
+        let line_num = excerpt_first_line + i;
+        if line_num == cursor.line {
+            let col = cursor.column.min(line.len());
+            let (before, after) = line.split_at(col);
+            *line = format!("{}<|user_cursor|>{}", before, after);
+            break;
+        }
+    }
+
+    Some(excerpt_lines.join("\n"))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_tokenize() {
+        let tokens = tokenize("hello world");
+        assert_eq!(tokens, vec!["hello", " ", "world"]);
+
+        let tokens = tokenize("foo_bar123 + baz");
+        assert_eq!(tokens, vec!["foo_", "bar123", " ", "+", " ", "baz"]);
+
+        let tokens = tokenize("print(\"hello\")");
+        assert_eq!(tokens, vec!["print", "(", "\"", "hello", "\"", ")"]);
+
+        let tokens = tokenize("hello_world");
+        assert_eq!(tokens, vec!["hello_", "world"]);
+
+        let tokens = tokenize("fn();");
+        assert_eq!(tokens, vec!["fn", "(", ")", ";"]);
+    }
+
+    #[test]
+    fn test_fuzzy_ratio() {
+        assert_eq!(fuzzy_ratio("hello", "hello"), 100);
+        assert_eq!(fuzzy_ratio("", ""), 100);
+        assert!(fuzzy_ratio("hello", "world") < 50);
+        assert!(fuzzy_ratio("hello world", "hello worl") > 80);
+    }
+
+    #[test]
+    fn test_split_ordered_commit() {
+        let commit = r#"// First change
+--- a/test.rs
++++ b/test.rs
+@@ -1,3 +1,4 @@
+ fn main() {
++    println!("hello");
++    println!("world");
+ }
+"#;
+        let patch = Patch::parse_unified_diff(commit);
+        let stats = patch.stats();
+        assert_eq!(stats.added, 2);
+
+        let (source, target) = split_ordered_commit(commit, 1);
+
+        // Source should have 1 addition
+        let src_patch = Patch::parse_unified_diff(&source);
+        assert_eq!(src_patch.stats().added, 1);
+
+        // Target should have 1 addition
+        let tgt_patch = Patch::parse_unified_diff(&target);
+        assert_eq!(tgt_patch.stats().added, 1);
+    }
+
+    #[test]
+    fn test_split_ordered_commit_with_deletions() {
+        let commit = r#"// Change
+--- a/test.rs
++++ b/test.rs
+@@ -1,3 +1,3 @@
+ fn main() {
+-    println!("old");
++    println!("new");
+ }
+"#;
+        let patch = Patch::parse_unified_diff(commit);
+        let stats = patch.stats();
+        assert_eq!(stats.added, 1);
+        assert_eq!(stats.removed, 1);
+
+        // Split at position 1 (after the deletion)
+        let (source, target) = split_ordered_commit(commit, 1);
+
+        let src_patch = Patch::parse_unified_diff(&source);
+        let tgt_patch =
Patch::parse_unified_diff(&target); + + // Source should have the deletion + assert_eq!(src_patch.stats().removed, 1); + // Target should have the addition + assert_eq!(tgt_patch.stats().added, 1); + } + + #[test] + fn test_generate_evaluation_example() { + let commit = r#"commit abc123 +Author: Test +Date: Mon Jan 1 00:00:00 2024 + + Test commit + +//////////////////////////////////////////////////////////////////////////////// +// Add greeting +//////////////////////////////////////////////////////////////////////////////// +--- a/test.rs ++++ b/test.rs +@@ -1,3 +1,5 @@ + fn main() { ++ println!("hello"); ++ println!("world"); + } +"#; + + let result = generate_evaluation_example_from_ordered_commit( + commit, + "https://github.com/test/repo", + "abc123", + Some(SplitPoint::Fraction(0.5)), + Some(42), + ); + + assert!(result.is_ok()); + let case = result.unwrap(); + assert_eq!(case.repository_url, "https://github.com/test/repo"); + assert_eq!(case.commit, "abc123~1"); + assert!(!case.edit_history.is_empty()); + } + + #[test] + fn test_generate_evaluation_example_reproducible() { + let commit = r#"//////////////////////////////////////////////////////////////////////////////// +// Add greeting +//////////////////////////////////////////////////////////////////////////////// +--- a/test.rs ++++ b/test.rs +@@ -1,3 +1,5 @@ + fn main() { ++ println!("hello"); ++ println!("world"); + } +"#; + + // Run twice with the same seed + let result1 = generate_evaluation_example_from_ordered_commit( + commit, + "https://github.com/test/repo", + "abc123", + Some(SplitPoint::Fraction(0.5)), + Some(12345), + ) + .unwrap(); + + let result2 = generate_evaluation_example_from_ordered_commit( + commit, + "https://github.com/test/repo", + "abc123", + Some(SplitPoint::Fraction(0.5)), + Some(12345), + ) + .unwrap(); + + // Results should be identical + assert_eq!(result1.edit_history, result2.edit_history); + assert_eq!(result1.expected_patch, result2.expected_patch); + assert_eq!(result1.cursor_position, result2.cursor_position); + } + + #[test] + fn test_cursor_position_display() { + let cursor = CursorPosition { + file: "src/main.rs".to_string(), + line: 42, + column: 10, + }; + assert_eq!(cursor.to_string(), "src/main.rs:42:10"); + } + + #[test] + fn test_imitate_human_edits_no_change_when_no_replacement() { + // Source and target patches that don't form a replacement pattern + let source = r#"--- a/test.rs ++++ b/test.rs +@@ -1,3 +1,4 @@ + fn main() { ++ println!("hello"); + } +"#; + let target = r#"--- a/test.rs ++++ b/test.rs +@@ -1,3 +1,4 @@ + fn main() { ++ println!("world"); + } +"#; + + let (new_src, new_tgt, cursor) = imitate_human_edits(source, target, 42); + + // Should return unchanged when not a replacement pattern + assert_eq!(new_src, source); + assert_eq!(new_tgt, target); + assert!(cursor.is_none()); + } + + #[test] + fn test_split_point_fraction() { + let commit = r#"// Change +--- a/test.rs ++++ b/test.rs +@@ -1,5 +1,10 @@ + fn main() { ++ line1(); ++ line2(); ++ line3(); ++ line4(); ++ line5(); + } +"#; + + // Split at 20% should give first edit in source + let result = generate_evaluation_example_from_ordered_commit( + commit, + "", + "hash", + Some(SplitPoint::Fraction(0.2)), + Some(1), + ); + + assert!(result.is_ok()); + let case = result.unwrap(); + + // Source should have some edits + let src_patch = Patch::parse_unified_diff(&case.edit_history[0]); + assert!(src_patch.stats().added > 0); + } + + #[test] + fn test_split_point_index() { + let commit = r#"// Change +--- a/test.rs ++++ 
b/test.rs +@@ -1,5 +1,10 @@ + fn main() { ++ line1(); ++ line2(); ++ line3(); ++ line4(); ++ line5(); + } +"#; + + // Split at index 2 should give first 2 edits in source + // With pure insertion handling, source gets 2 original + 1 partial = 3 additions + let result = generate_evaluation_example_from_ordered_commit( + commit, + "", + "hash", + Some(SplitPoint::Index(2)), + Some(1), + ); + + assert!(result.is_ok()); + let case = result.unwrap(); + + let src_patch = Patch::parse_unified_diff(&case.edit_history[0]); + // Pure insertion adds a partial line, so we expect 3 (2 original + 1 partial) + assert_eq!(src_patch.stats().added, 3); + } + + #[test] + fn test_cursor_excerpt_contains_marker() { + let commit = r#"//////////////////////////////////////////////////////////////////////////////// +// Add code +//////////////////////////////////////////////////////////////////////////////// +--- a/test.rs ++++ b/test.rs +@@ -1,3 +1,5 @@ + fn main() { ++ println!("hello"); ++ println!("world"); + } +"#; + + let result = generate_evaluation_example_from_ordered_commit( + commit, + "", + "hash", + Some(SplitPoint::Fraction(0.5)), + Some(42), + ) + .unwrap(); + + // Cursor excerpt should contain the cursor marker + assert!( + result.cursor_excerpt.contains("<|user_cursor|>"), + "Cursor excerpt should contain marker: {}", + result.cursor_excerpt + ); + } + + #[test] + fn test_evaluation_case_json_serialization() { + let case = EvaluationCase { + repository_url: "https://github.com/test/repo".to_string(), + commit: "abc123~1".to_string(), + edit_history: vec!["patch1".to_string()], + cursor_position: "file.rs:10:5".to_string(), + cursor_excerpt: "some code<|user_cursor|>".to_string(), + expected_hunks: vec!["hunk1".to_string()], + expected_patch: "patch".to_string(), + allowed_patch: "patch".to_string(), + expected_context_excerpts: vec![], + extra: serde_json::json!({}), + }; + + let json = serde_json::to_string(&case).unwrap(); + let deserialized: EvaluationCase = serde_json::from_str(&json).unwrap(); + + assert_eq!(case.repository_url, deserialized.repository_url); + assert_eq!(case.commit, deserialized.commit); + assert_eq!(case.cursor_position, deserialized.cursor_position); + } + + #[test] + fn test_empty_commit_returns_error() { + let commit = ""; + + let result = generate_evaluation_example_from_ordered_commit( + commit, + "", + "hash", + Some(SplitPoint::Fraction(0.5)), + Some(1), + ); + + assert!(result.is_err()); + } + + #[test] + fn test_header_filtering() { + let commit = r#"commit abc123 +Author: Test +Date: Today + + Message + +diff --git a/test.rs b/test.rs +index 123..456 789 +//////////////////////////////////////////////////////////////////////////////// +// First group +//////////////////////////////////////////////////////////////////////////////// +--- a/test.rs ++++ b/test.rs +@@ -1,3 +1,4 @@ + fn main() { ++ code(); + } +"#; + + let result = generate_evaluation_example_from_ordered_commit( + commit, + "", + "hash", + Some(SplitPoint::Index(1)), + Some(1), + ); + + assert!(result.is_ok()); + let case = result.unwrap(); + + // The edit history should contain the group header (// lines) + // but not the commit metadata + assert!(!case.edit_history[0].contains("Author:")); + assert!(!case.edit_history[0].contains("Date:")); + } + + #[test] + fn test_position_weight() { + // High weight positions (natural pause points) + assert_eq!(position_weight("foo(", 4), 10); // After '(' + assert_eq!(position_weight("a, b", 2), 10); // After ',' + assert_eq!(position_weight("x;", 2), 10); // 
After ';'
+        assert_eq!(position_weight("a: b", 2), 10); // After ':'
+        assert_eq!(position_weight("[", 1), 10); // After '['
+        assert_eq!(position_weight("{", 1), 10); // After '{'
+
+        // High weight for closing brackets
+        assert_eq!(position_weight("foo)", 4), 8); // After ')'
+        assert_eq!(position_weight("]", 1), 8); // After ']'
+        assert_eq!(position_weight("}", 1), 8); // After '}'
+
+        // High weight at end of identifier
+        assert_eq!(position_weight("foo ", 3), 8); // End of 'foo' before space
+        assert_eq!(position_weight("bar(", 3), 8); // End of 'bar' before '('
+
+        // Medium weight for operators
+        assert_eq!(position_weight("a + b", 3), 5); // After '+'
+        assert_eq!(position_weight("x.", 2), 5); // After '.'
+        assert_eq!(position_weight("a=b", 2), 5); // After '='
+
+        // Medium weight for whitespace
+        assert_eq!(position_weight("a ", 2), 6); // After space
+
+        // Low weight mid-identifier
+        assert_eq!(position_weight("foobar", 3), 1); // Mid-identifier 'foo|bar'
+
+        // Edge cases
+        assert_eq!(position_weight("", 0), 1); // Empty string
+        assert_eq!(position_weight("a", 0), 1); // Position 0
+    }
+
+    #[test]
+    fn test_weighted_select() {
+        // Test that weighted selection returns correct indices
+        let weights = vec![1, 10, 1];
+
+        // With total weight 12, seed 0 should select index 0
+        // seed 0 % 12 = 0, cumulative: 1 at idx 0, so returns 0
+        assert_eq!(weighted_select(&weights, 0), 0);
+
+        // seed 1 % 12 = 1, cumulative: 1 at idx 0 (1 < 1 is false), 11 at idx 1 (1 < 11 is true)
+        assert_eq!(weighted_select(&weights, 1), 1);
+
+        // seed 10 % 12 = 10, cumulative: 1, 11 at idx 1 (10 < 11 is true)
+        assert_eq!(weighted_select(&weights, 10), 1);
+
+        // seed 11 % 12 = 11, cumulative: 1, 11 at idx 1 (11 < 11 is false), 12 at idx 2 (11 < 12 is true)
+        assert_eq!(weighted_select(&weights, 11), 2);
+
+        // Empty weights should return 0
+        let empty: Vec<usize> = vec![];
+        assert_eq!(weighted_select(&empty, 42), 0);
+
+        // Single weight should always return index 0
+        let single = vec![10];
+        assert_eq!(weighted_select(&single, 0), 0);
+        assert_eq!(weighted_select(&single, 100), 0);
+    }
+
+    #[test]
+    fn test_weighted_split_prefers_natural_boundaries() {
+        // Test that with different seeds, weighted selection tends to prefer
+        // positions after punctuation over mid-identifier positions
+        let text_with_punctuation = "foo(bar, baz)";
+        let text_mid_identifier = "foobar";
+
+        // Position after '(' should have high weight
+        let weight_after_paren = position_weight(text_with_punctuation, 4);
+        // Position after ',' should have high weight
+        let weight_after_comma = position_weight(text_with_punctuation, 8);
+        // Position mid-identifier should have low weight
+        let weight_mid_ident = position_weight(text_mid_identifier, 3);
+
+        assert!(
+            weight_after_paren > weight_mid_ident,
+            "After '(' ({}) should be weighted higher than mid-identifier ({})",
+            weight_after_paren,
+            weight_mid_ident
+        );
+        assert!(
+            weight_after_comma > weight_mid_ident,
+            "After ',' ({}) should be weighted higher than mid-identifier ({})",
+            weight_after_comma,
+            weight_mid_ident
+        );
+    }
+
+    #[test]
+    fn test_imitate_human_edits_pure_insertion() {
+        // Source patch is empty (no edits yet)
+        // Target patch has a pure insertion (adding a new line)
+        let source = r#"--- a/test.rs
++++ b/test.rs
+@@ -1,2 +1,2 @@
+ fn main() {
+ }
+"#;
+        let target = r#"--- a/test.rs
++++ b/test.rs
+@@ -1,2 +1,3 @@
+ fn main() {
++    println!("debug");
+ }
+"#;
+
+        let (new_src, new_tgt, cursor) = imitate_human_edits(source, target, 42);
+
+        // Should have
transformed the patches + assert_ne!( + new_src, source, + "Source should be modified for pure insertion" + ); + assert_ne!( + new_tgt, target, + "Target should be modified for pure insertion" + ); + assert!(cursor.is_some(), "Cursor should be set"); + + // Source should now have a partial addition + let src_patch = Patch::parse_unified_diff(&new_src); + assert!( + src_patch.stats().added > 0, + "Source should have added lines" + ); + + // Target should have both a deletion (of partial) and addition (of full) + let tgt_patch = Patch::parse_unified_diff(&new_tgt); + assert!( + tgt_patch.stats().removed > 0, + "Target should have removed lines (partial)" + ); + assert!( + tgt_patch.stats().added > 0, + "Target should have added lines (full)" + ); + + // The cursor should be in test.rs + let cursor = cursor.unwrap(); + assert_eq!(cursor.file, "test.rs"); + } + + #[test] + fn test_imitate_human_edits_pure_insertion_empty_source() { + // Source patch has no hunks at all + let source = ""; + let target = r#"--- a/test.rs ++++ b/test.rs +@@ -1,2 +1,3 @@ + fn main() { ++ println!("hello"); + } +"#; + + let (new_src, _new_tgt, cursor) = imitate_human_edits(source, target, 123); + + // Should have created a source patch with partial insertion + assert!(!new_src.is_empty(), "Source should not be empty"); + assert!(cursor.is_some(), "Cursor should be set"); + + let src_patch = Patch::parse_unified_diff(&new_src); + assert!( + src_patch.stats().added > 0, + "Source should have added lines" + ); + } + + #[test] + fn test_imitate_human_edits_pure_insertion_intermediate_content() { + // Verify the actual intermediate content is a realistic partial typing state + let source = ""; + let target = r#"--- a/test.rs ++++ b/test.rs +@@ -1,2 +1,3 @@ + fn main() { ++ println!("hello world"); + } +"#; + + // Test with multiple seeds to see different split points + let mut found_partial = false; + for seed in 1..=50 { + let (new_src, new_tgt, cursor) = imitate_human_edits(source, target, seed); + + if cursor.is_some() { + let src_patch = Patch::parse_unified_diff(&new_src); + let tgt_patch = Patch::parse_unified_diff(&new_tgt); + + // Find the added line in source + for hunk in &src_patch.hunks { + for line in &hunk.lines { + if let PatchLine::Addition(content) = line { + // The partial line should be a prefix of the full line + let full_line = " println!(\"hello world\");"; + if content != full_line && full_line.starts_with(content) { + found_partial = true; + + // Verify target has the partial as deletion + let mut has_deletion = false; + for tgt_hunk in &tgt_patch.hunks { + for tgt_line in &tgt_hunk.lines { + if let PatchLine::Deletion(del_content) = tgt_line { + if del_content == content { + has_deletion = true; + } + } + } + } + assert!( + has_deletion, + "Target should have deletion of partial line" + ); + } + } + } + } + } + } + + assert!( + found_partial, + "At least one seed should produce a partial intermediate state" + ); + } +} diff --git a/typos.toml b/typos.toml index 8e42bd674a64d8adc1e684df181c8e4ce67988e9..37a7a37a43e891661ec885d1be21ea2f3b364d67 100644 --- a/typos.toml +++ b/typos.toml @@ -57,6 +57,8 @@ extend-exclude = [ "crates/multi_buffer/src/multi_buffer_tests.rs", # Macos apis "crates/gpui/src/platform/mac/dispatcher.rs", + # Tests contain partially incomplete words (by design) + "crates/edit_prediction_cli/src/split_commit.rs", ] [default]
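For reference, a sketch of how the pieces added in this diff compose end to end, mirroring the tests above. `build_case` is a hypothetical helper and the argument values are illustrative; the real wiring lives in the `split-commit` subcommand.

```rust
// Sketch only: not part of this diff. Mirrors the calls exercised by the
// tests in split_commit.rs.
fn build_case(ordered_commit: &str) -> EvaluationCase {
    generate_evaluation_example_from_ordered_commit(
        ordered_commit,                  // chronologically-ordered commit text
        "https://github.com/test/repo",  // recorded as case.repository_url
        "abc123",                        // case.commit becomes "abc123~1"
        Some(SplitPoint::Fraction(0.5)), // split halfway through the edit stream
        Some(42),                        // fixed seed keeps the split reproducible
    )
    .expect("commit must contain at least one edit")
}
```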