1//! `ep split-commit` implementation.
2//!
3//! This command generates a single evaluation example JSON object from a
4//! chronologically-ordered unified diff (a "commit").
5//!
6//! TODO: Port Python code to generate chronologically-ordered commits
7use crate::reorder_patch::{Patch, PatchLine, extract_edits, locate_edited_line};
8use anyhow::{Context as _, Result};
9use clap::Args;
10use rand::Rng;
11use rand::SeedableRng;
12use serde::{Deserialize, Serialize};
13use similar::{DiffTag, TextDiff};
14use std::collections::BTreeSet;
15use std::fs;
16use std::io::{self, Read};
17
/// `ep split-commit` CLI args.
// NOTE: the `///` field docs below double as `--help` text via the clap
// derive, so they are deliberately kept terse and unchanged.
#[derive(Debug, Args, Clone)]
pub struct SplitCommitArgs {
    /// Path to the commit file (use "-" for stdin)
    #[arg(long, short = 'c')]
    pub commit: String,

    /// Repository URL
    // Defaults to empty; copied verbatim into the output JSON.
    #[arg(long, short = 'r', default_value_t = String::new())]
    pub repository_url: String,

    /// Commit hash
    // Defaults to empty; the output records it as "<hash>~1" (parent commit).
    #[arg(long, default_value_t = String::new())]
    pub commit_hash: String,

    /// Split point (float 0.0-1.0 for fraction, or integer for index)
    // Parsed by `parse_split_point`; omitted => a random split is chosen.
    #[arg(long, short = 's')]
    pub split_point: Option<String>,

    /// Random seed for reproducibility
    #[arg(long)]
    pub seed: Option<u64>,

    /// Pretty-print JSON output
    #[arg(long, short = 'p')]
    pub pretty: bool,
}
45
/// Cursor position in a file.
///
/// Rendered by the `Display` impl as `file:line:column`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CursorPosition {
    /// Path of the file the cursor is in, as it appears in the patch.
    pub file: String,
    /// Line number taken from the patch's target line numbering
    /// (presumably 1-based — confirm against `locate_edited_line`).
    pub line: usize,
    /// Column within the line; producers in this module set it 1-based
    /// (e.g. `new_src.len() + 1` in `imitate_human_edits`).
    pub column: usize,
}
53
54impl std::fmt::Display for CursorPosition {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 write!(f, "{}:{}:{}", self.file, self.line, self.column)
57 }
58}
59
/// Represents a split commit with source and target patches.
#[derive(Debug, Clone)]
pub struct SplitCommit {
    /// Unified diff containing the edits made so far (becomes `edit_history`).
    pub source_patch: String,
    /// Unified diff containing the remaining edits (becomes `expected_patch`).
    pub target_patch: String,
}
66
/// The evaluation case structure that will be serialized to JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationCase {
    /// Repository the commit comes from (may be empty).
    pub repository_url: String,
    /// Commit identifier; generated as `"<hash>~1"` (the parent commit).
    pub commit: String,
    /// Past edits as unified diffs (currently a single source patch).
    pub edit_history: Vec<String>,
    /// Cursor rendered as `file:line:column`.
    pub cursor_position: String,
    /// Lines around the cursor with a `<|user_cursor|>` marker inserted.
    pub cursor_excerpt: String,
    /// Hunks the model is expected to produce (currently the whole target patch).
    pub expected_hunks: Vec<String>,
    /// The full expected unified diff (the target patch).
    pub expected_patch: String,
    /// Patch whose edits are allowed; currently identical to `expected_patch`.
    pub allowed_patch: String,
    /// Additional context excerpts (currently always empty).
    pub expected_context_excerpts: Vec<String>,
    /// Free-form extra metadata (currently an empty JSON object).
    pub extra: serde_json::Value,
}
81
/// Split point specification for evaluation generation.
#[derive(Debug, Clone)]
pub enum SplitPoint {
    /// Fraction of total edits (0.0 to 1.0); floored to an edit index
    /// and clamped to the number of edits.
    Fraction(f64),
    /// Absolute index; clamped to the number of edits.
    Index(usize),
}
90
91fn parse_split_point(value: &str) -> Option<SplitPoint> {
92 if value.contains('.') {
93 value.parse::<f64>().ok().map(SplitPoint::Fraction)
94 } else {
95 value.parse::<usize>().ok().map(SplitPoint::Index)
96 }
97}
98
99/// Entry point for the `ep split-commit` subcommand.
100///
101/// This runs synchronously and prints a single JSON object to stdout.
102pub fn run_split_commit(args: &SplitCommitArgs) -> Result<()> {
103 let commit = if args.commit == "-" {
104 let mut content = String::new();
105 io::stdin()
106 .read_to_string(&mut content)
107 .context("failed to read commit diff from stdin")?;
108 content
109 } else {
110 fs::read_to_string(&args.commit)
111 .with_context(|| format!("failed to read commit diff from {}", args.commit))?
112 };
113
114 let split_point = args.split_point.as_deref().and_then(parse_split_point);
115
116 let case = generate_evaluation_example_from_ordered_commit(
117 &commit,
118 &args.repository_url,
119 &args.commit_hash,
120 split_point,
121 args.seed,
122 )
123 .context("failed to generate evaluation example")?;
124
125 let json = if args.pretty {
126 serde_json::to_string_pretty(&case)
127 } else {
128 serde_json::to_string(&case)
129 }
130 .context("failed to serialize evaluation case as JSON")?;
131
132 println!("{json}");
133 Ok(())
134}
135
/// Main function to generate an evaluation example from an ordered commit.
///
/// # Arguments
/// * `commit` - Chronologically ordered unified diff of the commit
/// * `repository_url` - URL of the repository
/// * `commit_hash` - Hash of the commit
/// * `split_point` - Point at which the commit will be split (None for random)
/// * `seed` - Optional seed for randomness
///
/// # Errors
/// Fails when the commit contains no edited lines, or when a cursor
/// position / excerpt cannot be derived from the split patches.
pub fn generate_evaluation_example_from_ordered_commit(
    commit: &str,
    repository_url: &str,
    commit_hash: &str,
    split_point: Option<SplitPoint>,
    seed: Option<u64>,
) -> Result<EvaluationCase> {
    // Seeded StdRng for reproducibility; otherwise the thread-local RNG.
    let mut rng: Box<dyn rand::RngCore> = match seed {
        Some(seed) => Box::new(rand::rngs::StdRng::seed_from_u64(seed)),
        None => Box::new(rand::rngs::ThreadRng::default()),
    };

    // Parse and normalize the commit
    let mut patch = Patch::parse_unified_diff(commit);

    // Filter header to only keep lines starting with "//"
    // (drops the "commit …"/"Author: …" preamble, keeps group headers).
    let header_lines: Vec<&str> = patch
        .header
        .lines()
        .filter(|line| line.starts_with("//"))
        .collect();
    patch.header = if header_lines.is_empty() {
        String::new()
    } else {
        header_lines.join("\n") + "\n"
    };
    let commit_normalized = patch.to_string();

    // Compute the split point. An "edit" is any added or removed line.
    let stats = patch.stats();
    let num_edits = stats.added + stats.removed;

    anyhow::ensure!(num_edits != 0, "no edits found in commit");

    let split = match split_point {
        // Random split never picks 0, so the source always gets >= 1 edit.
        None => rng.random_range(1..=num_edits),
        Some(SplitPoint::Fraction(f)) => {
            let v = (f * num_edits as f64).floor() as usize;
            v.min(num_edits)
        }
        Some(SplitPoint::Index(i)) => i.min(num_edits),
    };

    // Split the commit into source and target patches
    let (prefix, suffix) = split_ordered_commit(&commit_normalized, split);

    let mut split_commit = SplitCommit {
        source_patch: prefix,
        target_patch: suffix,
    };

    // Imitate human edits: possibly turn the boundary line into a
    // partially-typed line, using a fresh sub-seed drawn from `rng`.
    let human_edit_seed = rng.random_range(1..=10000u64);
    let (src_patch, tgt_patch, cursor_opt) = imitate_human_edits(
        &split_commit.source_patch,
        &split_commit.target_patch,
        human_edit_seed,
    );
    split_commit.source_patch = src_patch;
    split_commit.target_patch = tgt_patch;

    // Sample cursor position when `imitate_human_edits` did not supply one.
    let cursor = match cursor_opt {
        Some(c) => c,
        None => sample_cursor_position(&patch, &split_commit)
            .context("failed to sample cursor position")?,
    };

    // Get cursor excerpt
    let cursor_excerpt = get_cursor_excerpt(
        &cursor,
        &split_commit.source_patch,
        &split_commit.target_patch,
    )
    .context("failed to generate cursor excerpt")?;

    // Handle edge case where split_point == 0
    // NOTE(review): this clears target_patch when *no* edits went into the
    // source, which discards every expected edit; clearing it when
    // split == num_edits would seem more natural. Also note the excerpt
    // above is computed before this truncation. TODO confirm intent.
    if split == 0 {
        split_commit.target_patch = String::new();
    }

    Ok(EvaluationCase {
        repository_url: repository_url.to_string(),
        // The case checks out the parent of the commit being predicted.
        commit: format!("{}~1", commit_hash),
        edit_history: vec![split_commit.source_patch.clone()],
        cursor_position: cursor.to_string(),
        cursor_excerpt,
        expected_hunks: vec![split_commit.target_patch.clone()],
        expected_patch: split_commit.target_patch.clone(),
        allowed_patch: split_commit.target_patch,
        expected_context_excerpts: vec![],
        extra: serde_json::json!({}),
    })
}
238
239/// Split an ordered commit into source and target commits.
240///
241/// # Arguments
242/// * `commit` - Ordered commit string
243/// * `split_pos` - Position to split the commit (number of edited lines)
244///
245/// # Returns
246/// A tuple of (source_diff, target_diff)
247pub fn split_ordered_commit(commit: &str, split_pos: usize) -> (String, String) {
248 let patch = Patch::parse_unified_diff(commit);
249 let source_edits: BTreeSet<usize> = (0..split_pos).collect();
250 let (source, target) = extract_edits(&patch, &source_edits);
251
252 let mut source_str = source.to_string();
253 let target_str = target.to_string();
254
255 // Strip last group header from the source (lines starting with "//" at the end)
256 let source_lines: Vec<&str> = source_str.lines().collect();
257 let mut end_idx = source_lines.len();
258 for i in (0..source_lines.len()).rev() {
259 if source_lines[i].starts_with("//") {
260 end_idx = i;
261 } else {
262 break;
263 }
264 }
265 if end_idx < source_lines.len() {
266 source_str = source_lines[..end_idx].join("\n");
267 if !source_str.is_empty() {
268 source_str.push('\n');
269 }
270 }
271
272 (source_str, target_str)
273}
274
/// Tokenize text into words and non-word characters.
///
/// Alphanumeric runs form word tokens. An underscore terminates the current
/// word but stays attached to it (`"foo_bar"` -> `["foo_", "bar"]`). Every
/// other character (punctuation, whitespace) becomes its own single-char
/// token. Returns an empty vector for empty input.
fn tokenize(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut current = String::new();

    for ch in text.chars() {
        if ch.is_alphanumeric() {
            current.push(ch);
        } else if ch == '_' {
            // Attach the underscore to the current word, then flush it.
            // (No emptiness check needed: `current` just received '_'.)
            current.push(ch);
            tokens.push(std::mem::take(&mut current));
        } else {
            // Punctuation or whitespace - flush current word first
            if !current.is_empty() {
                tokens.push(std::mem::take(&mut current));
            }
            // Each punctuation/whitespace is its own token
            tokens.push(ch.to_string());
        }
    }

    // Flush a trailing word, if any.
    if !current.is_empty() {
        tokens.push(current);
    }

    tokens
}
305
/// Calculate the weight for a split position based on the character at that position.
///
/// Higher weights indicate more natural pause points (e.g., after punctuation,
/// at identifier boundaries). Lower weights indicate less natural points
/// (e.g., mid-identifier).
///
/// `pos` is a *byte* offset into `text` (callers pass `String::len()`), and
/// must land on a char boundary; out-of-range or mid-char positions get the
/// neutral minimum weight. The previous implementation mixed byte offsets
/// with char indexing, which mis-weighted non-ASCII text and allocated a
/// `Vec<char>` on every call; behavior for ASCII input is unchanged.
fn position_weight(text: &str, pos: usize) -> u32 {
    if pos == 0 || pos > text.len() || !text.is_char_boundary(pos) {
        return 1;
    }

    // The character just before this position (what we just "typed").
    let prev_char = match text[..pos].chars().next_back() {
        Some(c) => c,
        None => return 1,
    };
    // The character right after the position, if any.
    let next_char = text[pos..].chars().next();

    // High weight: natural pause points (end of statement/argument, opening brackets)
    if matches!(prev_char, ',' | ';' | ':' | '(' | '[' | '{') {
        return 10;
    }

    // High weight: closing brackets (finished a group)
    if matches!(prev_char, ')' | ']' | '}') {
        return 8;
    }

    // Medium weight: operators and method chains
    if matches!(
        prev_char,
        '.' | '+' | '-' | '*' | '/' | '=' | '<' | '>' | '&' | '|' | '!'
    ) {
        return 5;
    }

    // Check if we're at the end of an identifier (word char followed by non-word char)
    let is_prev_word_char = prev_char.is_alphanumeric() || prev_char == '_';
    let is_next_word_char = next_char.is_some_and(|c| c.is_alphanumeric() || c == '_');

    if is_prev_word_char && !is_next_word_char {
        // End of identifier - high weight
        return 8;
    }

    // Whitespace is a natural pause
    if prev_char.is_whitespace() {
        return 6;
    }

    // Mid-identifier: low weight (rare autocomplete scenarios)
    if is_prev_word_char && is_next_word_char {
        return 1;
    }

    // Default medium-low weight
    3
}
365
/// Select a weighted random index from a list of weights.
///
/// Returns an index based on the weights, using the provided seed for
/// deterministic selection. An empty slice yields index 0; all-zero
/// weights fall back to a uniform `seed % len` choice.
fn weighted_select(weights: &[u32], seed: u64) -> usize {
    if weights.is_empty() {
        return 0;
    }

    let total: u64 = weights.iter().map(|&w| u64::from(w)).sum();
    if total == 0 {
        // Degenerate case: pick uniformly when every weight is zero.
        return seed as usize % weights.len();
    }

    // Walk the weights, consuming `remaining` until it drops below the
    // current weight; that bucket is the selected index.
    let mut remaining = seed % total;
    for (idx, &weight) in weights.iter().enumerate() {
        let weight = u64::from(weight);
        if remaining < weight {
            return idx;
        }
        remaining -= weight;
    }

    // Unreachable in practice (remaining < total), kept as a safe fallback.
    weights.len() - 1
}
395
396/// Calculate similarity ratio between two strings (0-100).
397fn fuzzy_ratio(s1: &str, s2: &str) -> u32 {
398 if s1.is_empty() && s2.is_empty() {
399 return 100;
400 }
401 if s1.is_empty() || s2.is_empty() {
402 return 0;
403 }
404
405 let diff = TextDiff::from_chars(s1, s2);
406 let matching: usize = diff
407 .ops()
408 .iter()
409 .filter_map(|op| {
410 if matches!(op.tag(), DiffTag::Equal) {
411 Some(op.new_range().len())
412 } else {
413 None
414 }
415 })
416 .sum();
417
418 let total = s1.len() + s2.len();
419 ((2 * matching * 100) / total) as u32
420}
421
/// Imitate human edits by introducing partial line edits.
///
/// This function simulates how a human might incrementally type code,
/// rather than making complete line replacements.
///
/// Given the `(source_patch, target_patch)` pair produced by splitting a
/// commit, it rewrites the boundary between them: the first added line of
/// the target appears "partially typed" — the source gains the partial
/// line as an addition, the target gains it as a deletion — and a cursor
/// position at the end of the partial text is returned. Whenever the
/// transformation does not apply, the inputs are returned unchanged with
/// a `None` cursor.
pub fn imitate_human_edits(
    source_patch: &str,
    target_patch: &str,
    seed: u64,
) -> (String, String, Option<CursorPosition>) {
    // Returned verbatim from every bail-out path below.
    let no_change = (source_patch.to_string(), target_patch.to_string(), None);

    let src_patch = Patch::parse_unified_diff(source_patch);
    let tgt_patch = Patch::parse_unified_diff(target_patch);

    if tgt_patch.hunks.is_empty() {
        return no_change;
    }

    // Try to locate the first edit in target
    let tgt_edit_loc = match locate_edited_line(&tgt_patch, 0) {
        Some(loc) => loc,
        None => return no_change,
    };

    // Only an addition can be "partially typed"; bail out otherwise.
    let tgt_is_addition = matches!(tgt_edit_loc.patch_line, PatchLine::Addition(_));
    if !tgt_is_addition {
        return no_change;
    }

    let tgt_line = match &tgt_edit_loc.patch_line {
        PatchLine::Addition(s) => s.clone(),
        _ => return no_change,
    };

    // Try to locate the last edit in source (index -1 = last edit).
    let src_edit_loc = locate_edited_line(&src_patch, -1);

    // Check if source has ANY edit at the same line as target's first edit
    // We need to iterate through all edits to check this
    let src_has_edit_at_target_line = {
        let mut found = false;
        let mut idx = 0isize;
        while let Some(loc) = locate_edited_line(&src_patch, idx) {
            if loc.filename == tgt_edit_loc.filename
                && loc.target_line_number == tgt_edit_loc.target_line_number
            {
                found = true;
                break;
            }
            idx += 1;
        }
        found
    };

    // Check if this is a replacement (deletion followed by insertion on the same line)
    // or a pure insertion (no corresponding deletion in source)
    let is_replacement = src_edit_loc.as_ref().map_or(false, |loc| {
        matches!(loc.patch_line, PatchLine::Deletion(_))
            && loc.filename == tgt_edit_loc.filename
            && loc.target_line_number == tgt_edit_loc.target_line_number
    });

    // If source has an edit at the same line but it's not a replacement (i.e., it's an addition),
    // we shouldn't process this as a pure insertion either
    if !is_replacement && src_has_edit_at_target_line {
        return no_change;
    }

    // The "starting text" being edited: the deleted line for a replacement,
    // or empty for a pure insertion.
    let src_line = if is_replacement {
        match &src_edit_loc.as_ref().unwrap().patch_line {
            PatchLine::Deletion(s) => s.clone(),
            _ => return no_change,
        }
    } else {
        // Pure insertion: source line is empty
        String::new()
    };

    // Don't process if source and target are the same
    if src_line == tgt_line {
        return no_change;
    }

    // Tokenize both lines
    let src_tokens = tokenize(&src_line);
    let tgt_tokens = tokenize(&tgt_line);

    // Convert to slices for similar
    let src_refs: Vec<&str> = src_tokens.iter().map(|s| s.as_str()).collect();
    let tgt_refs: Vec<&str> = tgt_tokens.iter().map(|s| s.as_str()).collect();

    // Use similar to get diff operations
    let diff = TextDiff::from_slices(&src_refs, &tgt_refs);

    // Build weights for each possible split position
    let mut position_weights: Vec<u32> = Vec::new();

    // Simulate the edit process to collect weights for all possible split positions.
    // Each character typed (insertions) or untyped (deletions) contributes one
    // weight, so `position_weights.len()` equals the total number of edit steps.
    // NOTE(review): `current_text.len()` is a byte offset — assumes the token
    // text is effectively single-byte per char; TODO confirm for non-ASCII.
    {
        let mut current_text = String::new();

        for op in diff.ops() {
            match op.tag() {
                DiffTag::Equal => {
                    for i in op.old_range() {
                        current_text.push_str(&src_tokens[i]);
                    }
                }
                DiffTag::Replace => {
                    let ins: String = op.new_range().map(|i| tgt_tokens[i].as_str()).collect();
                    let del: String = op.old_range().map(|i| src_tokens[i].as_str()).collect();

                    // For insertion part
                    for ch in ins.chars() {
                        current_text.push(ch);
                        let weight = position_weight(&current_text, current_text.len());
                        position_weights.push(weight);
                    }

                    // For deletion part (we're "untyping" from source)
                    for _ in del.chars() {
                        // Weight deletions lower as they represent removing text
                        position_weights.push(2);
                    }
                }
                DiffTag::Insert => {
                    let ins: String = op.new_range().map(|i| tgt_tokens[i].as_str()).collect();
                    for ch in ins.chars() {
                        current_text.push(ch);
                        let weight = position_weight(&current_text, current_text.len());
                        position_weights.push(weight);
                    }
                }
                DiffTag::Delete => {
                    let del: String = op.old_range().map(|i| src_tokens[i].as_str()).collect();
                    for _ in del.chars() {
                        // Weight deletions lower
                        position_weights.push(2);
                    }
                }
            }
        }
    }

    // Use weighted selection to choose split index
    if position_weights.is_empty() {
        return no_change;
    }
    let split_index = weighted_select(&position_weights, seed);

    // Second pass: replay the same edit sequence, stopping at `split_index`
    // to build `new_src`, the partially-typed line. `last_old_end` tracks how
    // far into the source tokens we consumed, for the remainder check below.
    let mut edit_index = 0usize;
    let mut new_src = String::new();
    let mut split_found = false;
    let mut last_old_end = 0usize;

    for op in diff.ops() {
        match op.tag() {
            DiffTag::Equal => {
                for i in op.old_range() {
                    new_src.push_str(&src_tokens[i]);
                }
                last_old_end = op.old_range().end;
            }
            DiffTag::Replace => {
                // Handle replace as delete + insert
                // NOTE(review): `&ins[..offset]` / `&del[..]` slice by byte
                // offset and would panic mid-char on multi-byte text —
                // same single-byte assumption as the weight pass above.
                let del: String = op.old_range().map(|i| src_tokens[i].as_str()).collect();
                let ins: String = op.new_range().map(|i| tgt_tokens[i].as_str()).collect();
                let repl_len = del.len() + ins.len();
                if edit_index + repl_len >= split_index {
                    // Split within this replace operation
                    let offset = split_index - edit_index;
                    if offset < ins.len() {
                        new_src.push_str(&ins[..offset]);
                    } else {
                        new_src.push_str(&ins);
                        let del_offset = offset - ins.len();
                        new_src.push_str(&del[..del_offset.min(del.len())]);
                    }
                    split_found = true;
                    last_old_end = op.old_range().end;
                    break;
                } else {
                    edit_index += repl_len;
                    new_src.push_str(&ins);
                    last_old_end = op.old_range().end;
                }
            }
            DiffTag::Insert => {
                let repl: String = op.new_range().map(|i| tgt_tokens[i].as_str()).collect();
                if edit_index + repl.len() >= split_index {
                    let offset = split_index - edit_index;
                    new_src.push_str(&repl[..offset]);
                    split_found = true;
                    break;
                } else {
                    edit_index += repl.len();
                    new_src.push_str(&repl);
                }
            }
            DiffTag::Delete => {
                let repl: String = op.old_range().map(|i| src_tokens[i].as_str()).collect();
                if edit_index + repl.len() >= split_index {
                    let offset = split_index - edit_index;
                    new_src.push_str(&repl[..offset]);
                    split_found = true;
                    last_old_end = op.old_range().start + offset.min(op.old_range().len());
                    break;
                } else {
                    edit_index += repl.len();
                    new_src.push_str(&repl);
                }
            }
        }
    }

    if !split_found {
        return no_change;
    }

    // Calculate cursor position: end of the typed prefix, 1-based column.
    let cursor = CursorPosition {
        file: tgt_edit_loc.filename.clone(),
        line: if is_replacement {
            src_edit_loc.as_ref().unwrap().source_line_number
        } else {
            tgt_edit_loc.target_line_number
        },
        column: new_src.len() + 1,
    };

    // Add remainder of source if similar enough to target remainder
    // (i.e., the "untyped" tail is still on screen while the user edits).
    let remainder_src: String = (last_old_end..src_tokens.len())
        .map(|i| src_tokens[i].as_str())
        .collect();
    let remainder_tgt: String = (last_old_end..tgt_tokens.len())
        .filter_map(|i| tgt_tokens.get(i).map(|s| s.as_str()))
        .collect();

    let ratio = fuzzy_ratio(&remainder_src, &remainder_tgt);
    if ratio > 35 {
        new_src.push_str(&remainder_src);
    }

    // Bail out if the partial line is degenerate (blank or identical to
    // the original source line) — the split would not change anything.
    if new_src.trim().is_empty() {
        return no_change;
    }

    if new_src == src_line {
        return no_change;
    }

    // Build new source patch with the intermediate line
    let mut new_src_patch = src_patch;
    if is_replacement {
        // For replacements, insert after the deletion line
        let src_loc = src_edit_loc.as_ref().unwrap();
        if let Some(hunk) = new_src_patch.hunks.get_mut(src_loc.hunk_index) {
            hunk.lines.insert(
                src_loc.line_index_within_hunk + 1,
                PatchLine::Addition(new_src.clone()),
            );
            hunk.new_count += 1;
        }
    } else {
        // For pure insertions, we need to add or modify a hunk
        if let Some(hunk) = new_src_patch.hunks.get_mut(tgt_edit_loc.hunk_index) {
            // Insert the partial line at the same position as target
            hunk.lines.insert(
                tgt_edit_loc.line_index_within_hunk,
                PatchLine::Addition(new_src.clone()),
            );
            hunk.new_count += 1;
        } else if new_src_patch.hunks.is_empty() {
            // Source patch is empty, create a new hunk based on target
            if let Some(tgt_hunk) = tgt_patch.hunks.get(tgt_edit_loc.hunk_index) {
                let mut new_hunk = tgt_hunk.clone();
                // Replace the full addition with the partial one
                new_hunk.lines.clear();
                for (i, line) in tgt_hunk.lines.iter().enumerate() {
                    if i == tgt_edit_loc.line_index_within_hunk {
                        new_hunk.lines.push(PatchLine::Addition(new_src.clone()));
                    } else {
                        match line {
                            PatchLine::Addition(_) => {
                                // Skip other additions from target
                            }
                            _ => new_hunk.lines.push(line.clone()),
                        }
                    }
                }
                new_hunk.new_count = new_hunk.old_count + 1;
                new_src_patch.hunks.push(new_hunk);
                // Copy header from target if source doesn't have one
                if new_src_patch.header.is_empty() {
                    new_src_patch.header = tgt_patch.header.clone();
                }
            }
        }
    }

    // Build new target patch with the intermediate line as deletion
    // (the target must now "undo" the partial line before adding the full one).
    let mut new_tgt_patch = tgt_patch;
    if let Some(hunk) = new_tgt_patch.hunks.get_mut(tgt_edit_loc.hunk_index) {
        hunk.lines.insert(
            tgt_edit_loc.line_index_within_hunk,
            PatchLine::Deletion(new_src),
        );
        hunk.old_count += 1;
    }

    (
        new_src_patch.to_string(),
        new_tgt_patch.to_string(),
        Some(cursor),
    )
}
739
740/// Locate the end of the last edit in a patch.
741fn locate_end_of_last_edit(patch: &Patch) -> Option<CursorPosition> {
742 let loc = locate_edited_line(patch, -1)?;
743
744 let (line, col) = match &loc.patch_line {
745 PatchLine::Addition(content) => (loc.target_line_number, content.len()),
746 PatchLine::Deletion(_) => (loc.target_line_number, 1),
747 _ => return None,
748 };
749
750 Some(CursorPosition {
751 file: loc.filename,
752 line,
753 column: col,
754 })
755}
756
757/// Locate the beginning of the first edit in a patch.
758fn locate_beginning_of_first_edit(patch: &Patch) -> Option<CursorPosition> {
759 let loc = locate_edited_line(patch, 0)?;
760
761 let hunk = patch.hunks.get(loc.hunk_index)?;
762 let column = if loc.line_index_within_hunk > 0 {
763 if let Some(prev_line) = hunk.lines.get(loc.line_index_within_hunk - 1) {
764 let content = match prev_line {
765 PatchLine::Context(s) | PatchLine::Addition(s) | PatchLine::Deletion(s) => s,
766 _ => return None,
767 };
768 content.len().max(1) - 1
769 } else {
770 0
771 }
772 } else {
773 0
774 };
775
776 let line = loc.target_line_number.saturating_sub(1).max(1);
777
778 Some(CursorPosition {
779 file: loc.filename,
780 line,
781 column,
782 })
783}
784
785/// Sample cursor position according to the following rules:
786/// 1. 50% chance of cursor being at the end of the source patch
787/// 2. 50% chance of cursor being at the beginning of the target patch
788pub fn sample_cursor_position(patch: &Patch, split_commit: &SplitCommit) -> Option<CursorPosition> {
789 // Try end of history first
790 let src_patch = Patch::parse_unified_diff(&split_commit.source_patch);
791 if let Some(cursor) = locate_end_of_last_edit(&src_patch) {
792 return Some(cursor);
793 }
794
795 // Try beginning of target
796 let tgt_patch = Patch::parse_unified_diff(&split_commit.target_patch);
797 if let Some(cursor) = locate_beginning_of_first_edit(&tgt_patch) {
798 return Some(cursor);
799 }
800
801 // Fallback: use the original patch
802 locate_end_of_last_edit(patch)
803}
804
805/// Get cursor excerpt from the patches.
806///
807/// This extracts the lines around the cursor position with a cursor marker.
808pub fn get_cursor_excerpt(
809 cursor: &CursorPosition,
810 source_patch: &str,
811 target_patch: &str,
812) -> Option<String> {
813 let mut excerpt_lines: Vec<String> = Vec::new();
814 let mut excerpt_first_line: usize = 0;
815
816 // Search in the last hunk of source patch
817 let src = Patch::parse_unified_diff(source_patch);
818 if let Some(loc) = locate_edited_line(&src, -1) {
819 if loc.filename == cursor.file && loc.target_line_number == cursor.line {
820 if let Some(hunk) = src.hunks.get(loc.hunk_index) {
821 excerpt_first_line = hunk.new_start as usize;
822 for line in &hunk.lines {
823 match line {
824 PatchLine::Addition(s) | PatchLine::Context(s) => {
825 excerpt_lines.push(s.clone());
826 }
827 _ => {}
828 }
829 }
830 }
831 }
832 }
833
834 // Search in target patch if not found
835 if excerpt_lines.is_empty() {
836 let tgt = Patch::parse_unified_diff(target_patch);
837 if let Some(loc) = locate_edited_line(&tgt, 0) {
838 if loc.filename == cursor.file {
839 if let Some(hunk) = tgt.hunks.get(loc.hunk_index) {
840 excerpt_first_line = hunk.new_start as usize;
841 for line in &hunk.lines {
842 match line {
843 PatchLine::Deletion(s) | PatchLine::Context(s) => {
844 excerpt_lines.push(s.clone());
845 }
846 _ => {}
847 }
848 }
849 }
850 }
851 }
852 }
853
854 if excerpt_lines.is_empty() {
855 return None;
856 }
857
858 // Add cursor marker
859 for (i, line) in excerpt_lines.iter_mut().enumerate() {
860 let line_num = excerpt_first_line + i;
861 if line_num == cursor.line {
862 let col = cursor.column.min(line.len());
863 let (before, after) = line.split_at(col);
864 *line = format!("{}<|user_cursor|>{}", before, after);
865 break;
866 }
867 }
868
869 Some(excerpt_lines.join("\n"))
870}
871
872#[cfg(test)]
873mod tests {
874 use super::*;
875
876 #[test]
877 fn test_tokenize() {
878 let tokens = tokenize("hello world");
879 assert_eq!(tokens, vec!["hello", " ", "world"]);
880
881 let tokens = tokenize("foo_bar123 + baz");
882 assert_eq!(tokens, vec!["foo_", "bar123", " ", "+", " ", "baz"]);
883
884 let tokens = tokenize("print(\"hello\")");
885 assert_eq!(tokens, vec!["print", "(", "\"", "hello", "\"", ")"]);
886
887 let tokens = tokenize("hello_world");
888 assert_eq!(tokens, vec!["hello_", "world"]);
889
890 let tokens = tokenize("fn();");
891 assert_eq!(tokens, vec!["fn", "(", ")", ";"]);
892 }
893
894 #[test]
895 fn test_fuzzy_ratio() {
896 assert_eq!(fuzzy_ratio("hello", "hello"), 100);
897 assert_eq!(fuzzy_ratio("", ""), 100);
898 assert!(fuzzy_ratio("hello", "world") < 50);
899 assert!(fuzzy_ratio("hello world", "hello worl") > 80);
900 }
901
902 #[test]
903 fn test_split_ordered_commit() {
904 let commit = r#"// First change
905--- a/test.rs
906+++ b/test.rs
907@@ -1,3 +1,4 @@
908 fn main() {
909+ println!("hello");
910+ println!("world");
911 }
912"#;
913 let patch = Patch::parse_unified_diff(commit);
914 let stats = patch.stats();
915 assert_eq!(stats.added, 2);
916
917 let (source, target) = split_ordered_commit(commit, 1);
918
919 // Source should have 1 addition
920 let src_patch = Patch::parse_unified_diff(&source);
921 assert_eq!(src_patch.stats().added, 1);
922
923 // Target should have 1 addition
924 let tgt_patch = Patch::parse_unified_diff(&target);
925 assert_eq!(tgt_patch.stats().added, 1);
926 }
927
928 #[test]
929 fn test_split_ordered_commit_with_deletions() {
930 let commit = r#"// Change
931--- a/test.rs
932+++ b/test.rs
933@@ -1,3 +1,3 @@
934 fn main() {
935- println!("old");
936+ println!("new");
937 }
938"#;
939 let patch = Patch::parse_unified_diff(commit);
940 let stats = patch.stats();
941 assert_eq!(stats.added, 1);
942 assert_eq!(stats.removed, 1);
943
944 // Split at position 1 (after the deletion)
945 let (source, target) = split_ordered_commit(commit, 1);
946
947 let src_patch = Patch::parse_unified_diff(&source);
948 let tgt_patch = Patch::parse_unified_diff(&target);
949
950 // Source should have the deletion
951 assert_eq!(src_patch.stats().removed, 1);
952 // Target should have the addition
953 assert_eq!(tgt_patch.stats().added, 1);
954 }
955
956 #[test]
957 fn test_generate_evaluation_example() {
958 let commit = r#"commit abc123
959Author: Test <test@example.com>
960Date: Mon Jan 1 00:00:00 2024
961
962 Test commit
963
964////////////////////////////////////////////////////////////////////////////////
965// Add greeting
966////////////////////////////////////////////////////////////////////////////////
967--- a/test.rs
968+++ b/test.rs
969@@ -1,3 +1,5 @@
970 fn main() {
971+ println!("hello");
972+ println!("world");
973 }
974"#;
975
976 let result = generate_evaluation_example_from_ordered_commit(
977 commit,
978 "https://github.com/test/repo",
979 "abc123",
980 Some(SplitPoint::Fraction(0.5)),
981 Some(42),
982 );
983
984 assert!(result.is_ok());
985 let case = result.unwrap();
986 assert_eq!(case.repository_url, "https://github.com/test/repo");
987 assert_eq!(case.commit, "abc123~1");
988 assert!(!case.edit_history.is_empty());
989 }
990
991 #[test]
992 fn test_generate_evaluation_example_reproducible() {
993 let commit = r#"////////////////////////////////////////////////////////////////////////////////
994// Add greeting
995////////////////////////////////////////////////////////////////////////////////
996--- a/test.rs
997+++ b/test.rs
998@@ -1,3 +1,5 @@
999 fn main() {
1000+ println!("hello");
1001+ println!("world");
1002 }
1003"#;
1004
1005 // Run twice with the same seed
1006 let result1 = generate_evaluation_example_from_ordered_commit(
1007 commit,
1008 "https://github.com/test/repo",
1009 "abc123",
1010 Some(SplitPoint::Fraction(0.5)),
1011 Some(12345),
1012 )
1013 .unwrap();
1014
1015 let result2 = generate_evaluation_example_from_ordered_commit(
1016 commit,
1017 "https://github.com/test/repo",
1018 "abc123",
1019 Some(SplitPoint::Fraction(0.5)),
1020 Some(12345),
1021 )
1022 .unwrap();
1023
1024 // Results should be identical
1025 assert_eq!(result1.edit_history, result2.edit_history);
1026 assert_eq!(result1.expected_patch, result2.expected_patch);
1027 assert_eq!(result1.cursor_position, result2.cursor_position);
1028 }
1029
1030 #[test]
1031 fn test_cursor_position_display() {
1032 let cursor = CursorPosition {
1033 file: "src/main.rs".to_string(),
1034 line: 42,
1035 column: 10,
1036 };
1037 assert_eq!(cursor.to_string(), "src/main.rs:42:10");
1038 }
1039
1040 #[test]
1041 fn test_imitate_human_edits_no_change_when_no_replacement() {
1042 // Source and target patches that don't form a replacement pattern
1043 let source = r#"--- a/test.rs
1044+++ b/test.rs
1045@@ -1,3 +1,4 @@
1046 fn main() {
1047+ println!("hello");
1048 }
1049"#;
1050 let target = r#"--- a/test.rs
1051+++ b/test.rs
1052@@ -1,3 +1,4 @@
1053 fn main() {
1054+ println!("world");
1055 }
1056"#;
1057
1058 let (new_src, new_tgt, cursor) = imitate_human_edits(source, target, 42);
1059
1060 // Should return unchanged when not a replacement pattern
1061 assert_eq!(new_src, source);
1062 assert_eq!(new_tgt, target);
1063 assert!(cursor.is_none());
1064 }
1065
1066 #[test]
1067 fn test_split_point_fraction() {
1068 let commit = r#"// Change
1069--- a/test.rs
1070+++ b/test.rs
1071@@ -1,5 +1,10 @@
1072 fn main() {
1073+ line1();
1074+ line2();
1075+ line3();
1076+ line4();
1077+ line5();
1078 }
1079"#;
1080
1081 // Split at 20% should give first edit in source
1082 let result = generate_evaluation_example_from_ordered_commit(
1083 commit,
1084 "",
1085 "hash",
1086 Some(SplitPoint::Fraction(0.2)),
1087 Some(1),
1088 );
1089
1090 assert!(result.is_ok());
1091 let case = result.unwrap();
1092
1093 // Source should have some edits
1094 let src_patch = Patch::parse_unified_diff(&case.edit_history[0]);
1095 assert!(src_patch.stats().added > 0);
1096 }
1097
1098 #[test]
1099 fn test_split_point_index() {
1100 let commit = r#"// Change
1101--- a/test.rs
1102+++ b/test.rs
1103@@ -1,5 +1,10 @@
1104 fn main() {
1105+ line1();
1106+ line2();
1107+ line3();
1108+ line4();
1109+ line5();
1110 }
1111"#;
1112
1113 // Split at index 2 should give first 2 edits in source
1114 // With pure insertion handling, source gets 2 original + 1 partial = 3 additions
1115 let result = generate_evaluation_example_from_ordered_commit(
1116 commit,
1117 "",
1118 "hash",
1119 Some(SplitPoint::Index(2)),
1120 Some(1),
1121 );
1122
1123 assert!(result.is_ok());
1124 let case = result.unwrap();
1125
1126 let src_patch = Patch::parse_unified_diff(&case.edit_history[0]);
1127 // Pure insertion adds a partial line, so we expect 3 (2 original + 1 partial)
1128 assert_eq!(src_patch.stats().added, 3);
1129 }
1130
1131 #[test]
1132 fn test_cursor_excerpt_contains_marker() {
1133 let commit = r#"////////////////////////////////////////////////////////////////////////////////
1134// Add code
1135////////////////////////////////////////////////////////////////////////////////
1136--- a/test.rs
1137+++ b/test.rs
1138@@ -1,3 +1,5 @@
1139 fn main() {
1140+ println!("hello");
1141+ println!("world");
1142 }
1143"#;
1144
1145 let result = generate_evaluation_example_from_ordered_commit(
1146 commit,
1147 "",
1148 "hash",
1149 Some(SplitPoint::Fraction(0.5)),
1150 Some(42),
1151 )
1152 .unwrap();
1153
1154 // Cursor excerpt should contain the cursor marker
1155 assert!(
1156 result.cursor_excerpt.contains("<|user_cursor|>"),
1157 "Cursor excerpt should contain marker: {}",
1158 result.cursor_excerpt
1159 );
1160 }
1161
1162 #[test]
1163 fn test_evaluation_case_json_serialization() {
1164 let case = EvaluationCase {
1165 repository_url: "https://github.com/test/repo".to_string(),
1166 commit: "abc123~1".to_string(),
1167 edit_history: vec!["patch1".to_string()],
1168 cursor_position: "file.rs:10:5".to_string(),
1169 cursor_excerpt: "some code<|user_cursor|>".to_string(),
1170 expected_hunks: vec!["hunk1".to_string()],
1171 expected_patch: "patch".to_string(),
1172 allowed_patch: "patch".to_string(),
1173 expected_context_excerpts: vec![],
1174 extra: serde_json::json!({}),
1175 };
1176
1177 let json = serde_json::to_string(&case).unwrap();
1178 let deserialized: EvaluationCase = serde_json::from_str(&json).unwrap();
1179
1180 assert_eq!(case.repository_url, deserialized.repository_url);
1181 assert_eq!(case.commit, deserialized.commit);
1182 assert_eq!(case.cursor_position, deserialized.cursor_position);
1183 }
1184
1185 #[test]
1186 fn test_empty_commit_returns_error() {
1187 let commit = "";
1188
1189 let result = generate_evaluation_example_from_ordered_commit(
1190 commit,
1191 "",
1192 "hash",
1193 Some(SplitPoint::Fraction(0.5)),
1194 Some(1),
1195 );
1196
1197 assert!(result.is_err());
1198 }
1199
1200 #[test]
1201 fn test_header_filtering() {
1202 let commit = r#"commit abc123
1203Author: Test
1204Date: Today
1205
1206 Message
1207
1208diff --git a/test.rs b/test.rs
1209index 123..456 789
1210////////////////////////////////////////////////////////////////////////////////
1211// First group
1212////////////////////////////////////////////////////////////////////////////////
1213--- a/test.rs
1214+++ b/test.rs
1215@@ -1,3 +1,4 @@
1216 fn main() {
1217+ code();
1218 }
1219"#;
1220
1221 let result = generate_evaluation_example_from_ordered_commit(
1222 commit,
1223 "",
1224 "hash",
1225 Some(SplitPoint::Index(1)),
1226 Some(1),
1227 );
1228
1229 assert!(result.is_ok());
1230 let case = result.unwrap();
1231
1232 // The edit history should contain the group header (// lines)
1233 // but not the commit metadata
1234 assert!(!case.edit_history[0].contains("Author:"));
1235 assert!(!case.edit_history[0].contains("Date:"));
1236 }
1237
1238 #[test]
1239 fn test_position_weight() {
1240 // High weight positions (natural pause points)
1241 assert_eq!(position_weight("foo(", 4), 10); // After '('
1242 assert_eq!(position_weight("a, b", 2), 10); // After ','
1243 assert_eq!(position_weight("x;", 2), 10); // After ';'
1244 assert_eq!(position_weight("a: b", 2), 10); // After ':'
1245 assert_eq!(position_weight("[", 1), 10); // After '['
1246 assert_eq!(position_weight("{", 1), 10); // After '{'
1247
1248 // High weight for closing brackets
1249 assert_eq!(position_weight("foo)", 4), 8); // After ')'
1250 assert_eq!(position_weight("]", 1), 8); // After ']'
1251 assert_eq!(position_weight("}", 1), 8); // After '}'
1252
1253 // High weight at end of identifier
1254 assert_eq!(position_weight("foo ", 3), 8); // End of 'foo' before space
1255 assert_eq!(position_weight("bar(", 3), 8); // End of 'bar' before '('
1256
1257 // Medium weight for operators
1258 assert_eq!(position_weight("a + b", 3), 5); // After '+'
1259 assert_eq!(position_weight("x.", 2), 5); // After '.'
1260 assert_eq!(position_weight("a=b", 2), 5); // After '='
1261
1262 // Medium weight for whitespace
1263 assert_eq!(position_weight("a ", 2), 6); // After space
1264
1265 // Low weight mid-identifier
1266 assert_eq!(position_weight("foobar", 3), 1); // Mid-identifier 'foo|bar'
1267
1268 // Edge cases
1269 assert_eq!(position_weight("", 0), 1); // Empty string
1270 assert_eq!(position_weight("a", 0), 1); // Position 0
1271 }
1272
1273 #[test]
1274 fn test_weighted_select() {
1275 // Test that weighted selection returns correct indices
1276 let weights = vec![1, 10, 1];
1277
1278 // With total weight 12, seed 0 should select index 0
1279 // seed 0 % 12 = 0, cumulative: 1 at idx 0, so returns 0
1280 assert_eq!(weighted_select(&weights, 0), 0);
1281
1282 // seed 1 % 12 = 1, cumulative: 1 at idx 0 (1 < 1 is false), 11 at idx 1 (1 < 11 is true)
1283 assert_eq!(weighted_select(&weights, 1), 1);
1284
1285 // seed 10 % 12 = 10, cumulative: 1, 11 at idx 1 (10 < 11 is true)
1286 assert_eq!(weighted_select(&weights, 10), 1);
1287
1288 // seed 11 % 12 = 11, cumulative: 1, 11 at idx 1 (11 < 11 is false), 12 at idx 2 (11 < 12 is true)
1289 assert_eq!(weighted_select(&weights, 11), 2);
1290
1291 // Empty weights should return 0
1292 let empty: Vec<u32> = vec![];
1293 assert_eq!(weighted_select(&empty, 42), 0);
1294
1295 // Single weight should always return index 0
1296 let single = vec![10];
1297 assert_eq!(weighted_select(&single, 0), 0);
1298 assert_eq!(weighted_select(&single, 100), 0);
1299 }
1300
1301 #[test]
1302 fn test_weighted_split_prefers_natural_boundaries() {
1303 // Test that with different seeds, weighted selection tends to prefer
1304 // positions after punctuation over mid-identifier positions
1305 let text_with_punctuation = "foo(bar, baz)";
1306 let text_mid_identifier = "foobar";
1307
1308 // Position after '(' should have high weight
1309 let weight_after_paren = position_weight(text_with_punctuation, 4);
1310 // Position after ',' should have high weight
1311 let weight_after_comma = position_weight(text_with_punctuation, 8);
1312 // Position mid-identifier should have low weight
1313 let weight_mid_ident = position_weight(text_mid_identifier, 3);
1314
1315 assert!(
1316 weight_after_paren > weight_mid_ident,
1317 "After '(' ({}) should be weighted higher than mid-identifier ({})",
1318 weight_after_paren,
1319 weight_mid_ident
1320 );
1321 assert!(
1322 weight_after_comma > weight_mid_ident,
1323 "After ',' ({}) should be weighted higher than mid-identifier ({})",
1324 weight_after_comma,
1325 weight_mid_ident
1326 );
1327 }
1328
1329 #[test]
1330 fn test_imitate_human_edits_pure_insertion() {
1331 // Source patch is empty (no edits yet)
1332 // Target patch has a pure insertion (adding a new line)
1333 let source = r#"--- a/test.rs
1334+++ b/test.rs
1335@@ -1,2 +1,2 @@
1336 fn main() {
1337 }
1338"#;
1339 let target = r#"--- a/test.rs
1340+++ b/test.rs
1341@@ -1,2 +1,3 @@
1342 fn main() {
1343+ println!("debug");
1344 }
1345"#;
1346
1347 let (new_src, new_tgt, cursor) = imitate_human_edits(source, target, 42);
1348
1349 // Should have transformed the patches
1350 assert_ne!(
1351 new_src, source,
1352 "Source should be modified for pure insertion"
1353 );
1354 assert_ne!(
1355 new_tgt, target,
1356 "Target should be modified for pure insertion"
1357 );
1358 assert!(cursor.is_some(), "Cursor should be set");
1359
1360 // Source should now have a partial addition
1361 let src_patch = Patch::parse_unified_diff(&new_src);
1362 assert!(
1363 src_patch.stats().added > 0,
1364 "Source should have added lines"
1365 );
1366
1367 // Target should have both a deletion (of partial) and addition (of full)
1368 let tgt_patch = Patch::parse_unified_diff(&new_tgt);
1369 assert!(
1370 tgt_patch.stats().removed > 0,
1371 "Target should have removed lines (partial)"
1372 );
1373 assert!(
1374 tgt_patch.stats().added > 0,
1375 "Target should have added lines (full)"
1376 );
1377
1378 // The cursor should be in test.rs
1379 let cursor = cursor.unwrap();
1380 assert_eq!(cursor.file, "test.rs");
1381 }
1382
1383 #[test]
1384 fn test_imitate_human_edits_pure_insertion_empty_source() {
1385 // Source patch has no hunks at all
1386 let source = "";
1387 let target = r#"--- a/test.rs
1388+++ b/test.rs
1389@@ -1,2 +1,3 @@
1390 fn main() {
1391+ println!("hello");
1392 }
1393"#;
1394
1395 let (new_src, _new_tgt, cursor) = imitate_human_edits(source, target, 123);
1396
1397 // Should have created a source patch with partial insertion
1398 assert!(!new_src.is_empty(), "Source should not be empty");
1399 assert!(cursor.is_some(), "Cursor should be set");
1400
1401 let src_patch = Patch::parse_unified_diff(&new_src);
1402 assert!(
1403 src_patch.stats().added > 0,
1404 "Source should have added lines"
1405 );
1406 }
1407
1408 #[test]
1409 fn test_imitate_human_edits_pure_insertion_intermediate_content() {
1410 // Verify the actual intermediate content is a realistic partial typing state
1411 let source = "";
1412 let target = r#"--- a/test.rs
1413+++ b/test.rs
1414@@ -1,2 +1,3 @@
1415 fn main() {
1416+ println!("hello world");
1417 }
1418"#;
1419
1420 // Test with multiple seeds to see different split points
1421 let mut found_partial = false;
1422 for seed in 1..=50 {
1423 let (new_src, new_tgt, cursor) = imitate_human_edits(source, target, seed);
1424
1425 if cursor.is_some() {
1426 let src_patch = Patch::parse_unified_diff(&new_src);
1427 let tgt_patch = Patch::parse_unified_diff(&new_tgt);
1428
1429 // Find the added line in source
1430 for hunk in &src_patch.hunks {
1431 for line in &hunk.lines {
1432 if let PatchLine::Addition(content) = line {
1433 // The partial line should be a prefix of the full line
1434 let full_line = " println!(\"hello world\");";
1435 if content != full_line && full_line.starts_with(content) {
1436 found_partial = true;
1437
1438 // Verify target has the partial as deletion
1439 let mut has_deletion = false;
1440 for tgt_hunk in &tgt_patch.hunks {
1441 for tgt_line in &tgt_hunk.lines {
1442 if let PatchLine::Deletion(del_content) = tgt_line {
1443 if del_content == content {
1444 has_deletion = true;
1445 }
1446 }
1447 }
1448 }
1449 assert!(
1450 has_deletion,
1451 "Target should have deletion of partial line"
1452 );
1453 }
1454 }
1455 }
1456 }
1457 }
1458 }
1459
1460 assert!(
1461 found_partial,
1462 "At least one seed should produce a partial intermediate state"
1463 );
1464 }
1465}