1//! `ep split-commit` implementation.
2//!
3//! This command generates a single evaluation example JSON object from a
4//! chronologically-ordered unified diff (a "commit").
5//!
6//! TODO: Port Python code to generate chronologically-ordered commits
7use crate::FailedHandling;
8use crate::reorder_patch::{Patch, PatchLine, extract_edits, locate_edited_line};
9
/// Find the largest valid UTF-8 char boundary at or before `index` in `s`.
fn floor_char_boundary(s: &str, index: usize) -> usize {
    // Past-the-end indices clamp to the length, which is always a boundary.
    if index >= s.len() {
        return s.len();
    }
    // Walk downward until we land on a boundary; offset 0 is always one,
    // so this loop terminates.
    let mut boundary = index;
    while !s.is_char_boundary(boundary) {
        boundary -= 1;
    }
    boundary
}
24use anyhow::{Context as _, Result};
25use clap::Args;
26use edit_prediction::example_spec::ExampleSpec;
27use rand::Rng;
28use rand::SeedableRng;
29use serde::{Deserialize, Serialize};
30use similar::{DiffTag, TextDiff};
31use std::collections::BTreeSet;
32use std::fs;
33use std::io::{self, Write};
34use std::path::Path;
35use std::path::PathBuf;
36
/// `ep split-commit` CLI args.
// NOTE: the `///` comments on the fields below double as clap help text, so
// they are left untouched; explanatory notes use plain `//` comments.
#[derive(Debug, Args, Clone)]
pub struct SplitCommitArgs {
    /// Split point (float 0.0-1.0 for fraction, or integer for index)
    // Parsed by `parse_split_point`: a value containing '.' becomes
    // `SplitPoint::Fraction`, anything else `SplitPoint::Index`.
    #[arg(long, short = 's')]
    pub split_point: Option<String>,

    /// Random seed for reproducibility
    // With `--num-samples`, per-sample seeds are derived from this base seed
    // by adding the sample index (see `run_split_commit`).
    #[arg(long)]
    pub seed: Option<u64>,

    /// Pretty-print JSON output
    #[arg(long, short = 'p')]
    pub pretty: bool,

    /// Number of samples to generate per commit (samples random split points)
    // When set, `split_point` is ignored and each sample uses a random split.
    #[arg(long, short = 'n')]
    pub num_samples: Option<usize>,
}
56
/// Input format for annotated commits (JSON Lines).
///
/// One object per input line; only `reordered_commit`, `repo_url`, and
/// `commit_sha` are read by `run_split_commit`, the rest are carried for
/// provenance.
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)]
pub struct AnnotatedCommit {
    /// Repository path (e.g., "repos/zed")
    pub repo: String,
    /// Repository URL (e.g., "https://github.com/zed-industries/zed")
    pub repo_url: String,
    /// Commit SHA
    pub commit_sha: String,
    /// Chronologically reordered commit diff
    pub reordered_commit: String,
    /// Original commit diff
    pub original_commit: String,
    /// Whether diff stats match between original and reordered
    pub diff_stats_match: bool,
}
74
/// Cursor position in a file.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CursorPosition {
    /// Path of the file the cursor is in (taken from patch filenames).
    pub file: String,
    /// Line number, taken from patch line numbers — presumably 1-based;
    /// TODO confirm against `reorder_patch::locate_edited_line`.
    pub line: usize,
    /// Column within the line. Producers compute this from byte lengths
    /// (e.g. `new_src.len() + 1` in `imitate_human_edits`), so this is a
    /// byte offset, not a character count.
    pub column: usize,
}
82
impl std::fmt::Display for CursorPosition {
    /// Formats the position as `file:line:column`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}:{}:{}", self.file, self.line, self.column)
    }
}
88
/// Represents a split commit with source and target patches.
#[derive(Debug, Clone)]
pub struct SplitCommit {
    /// Diff of the edits "already made"; becomes the example's edit history.
    pub source_patch: String,
    /// Diff of the remaining edits; becomes the example's expected patch.
    pub target_patch: String,
}
95
/// Split point specification for evaluation generation.
#[derive(Debug, Clone)]
pub enum SplitPoint {
    /// Fraction of total edits (0.0 to 1.0)
    Fraction(f64),
    /// Absolute index
    Index(usize),
}

/// Parse a CLI split-point string.
///
/// A value containing a decimal point is interpreted as a fraction, anything
/// else as an absolute index. Returns `None` if parsing fails.
fn parse_split_point(value: &str) -> Option<SplitPoint> {
    if value.contains('.') {
        let fraction = value.parse::<f64>().ok()?;
        Some(SplitPoint::Fraction(fraction))
    } else {
        let index = value.parse::<usize>().ok()?;
        Some(SplitPoint::Index(index))
    }
}
112
113/// Entry point for the `ep split-commit` subcommand.
114///
115/// This runs synchronously and outputs JSON Lines (one output per input line).
116pub fn run_split_commit(
117 args: &SplitCommitArgs,
118 inputs: &[PathBuf],
119 output_path: Option<&PathBuf>,
120 failed: FailedHandling,
121) -> Result<()> {
122 use std::collections::HashSet;
123 use std::io::BufRead;
124
125 let stdin_path = PathBuf::from("-");
126 let inputs = if inputs.is_empty() {
127 std::slice::from_ref(&stdin_path)
128 } else {
129 inputs
130 };
131
132 let split_point = args.split_point.as_deref().and_then(parse_split_point);
133 let mut output_lines = Vec::new();
134
135 for input_path in inputs {
136 let input: Box<dyn BufRead> = if input_path.as_os_str() == "-" {
137 Box::new(io::BufReader::new(io::stdin()))
138 } else {
139 let file = fs::File::open(input_path)
140 .with_context(|| format!("failed to open input file {}", input_path.display()))?;
141 Box::new(io::BufReader::new(file))
142 };
143
144 for (line_num, line_result) in input.lines().enumerate() {
145 let line =
146 line_result.with_context(|| format!("failed to read line {}", line_num + 1))?;
147
148 if line.trim().is_empty() {
149 continue;
150 }
151
152 let annotated: AnnotatedCommit = serde_json::from_str(&line)
153 .with_context(|| format!("failed to parse JSON at line {}", line_num + 1))?;
154
155 // Generate multiple samples if num_samples is set
156 if let Some(num_samples) = args.num_samples {
157 let mut seen_samples: HashSet<String> = HashSet::new();
158 let base_seed = args.seed.unwrap_or_else(|| rand::random());
159
160 for sample_idx in 0..num_samples {
161 let sample_seed = base_seed.wrapping_add(sample_idx as u64);
162
163 let case = match generate_evaluation_example_from_ordered_commit(
164 &annotated.reordered_commit,
165 &annotated.repo_url,
166 &annotated.commit_sha,
167 None, // Use random split point for multi-sample mode
168 Some(sample_seed),
169 Some(sample_idx),
170 ) {
171 Ok(case) => case,
172 Err(e) => {
173 let err_msg = format!(
174 "failed to generate evaluation example for commit {} at line {} (sample {}): {}",
175 annotated.commit_sha,
176 line_num + 1,
177 sample_idx,
178 e
179 );
180 match failed {
181 FailedHandling::Skip | FailedHandling::SkipNoFiles => {
182 eprintln!("{}", err_msg);
183 continue;
184 }
185 FailedHandling::Keep => {
186 anyhow::bail!(err_msg);
187 }
188 }
189 }
190 };
191
192 let json = if args.pretty {
193 serde_json::to_string_pretty(&case)
194 } else {
195 serde_json::to_string(&case)
196 }
197 .context("failed to serialize evaluation case as JSON")?;
198
199 // Only add unique samples (different split points may produce same result)
200 if seen_samples.insert(json.clone()) {
201 output_lines.push(json);
202 }
203 }
204 } else {
205 let case = match generate_evaluation_example_from_ordered_commit(
206 &annotated.reordered_commit,
207 &annotated.repo_url,
208 &annotated.commit_sha,
209 split_point.clone(),
210 args.seed,
211 None,
212 ) {
213 Ok(case) => case,
214 Err(e) => {
215 let err_msg = format!(
216 "failed to generate evaluation example for commit {} at line {}: {}",
217 annotated.commit_sha,
218 line_num + 1,
219 e
220 );
221 match failed {
222 FailedHandling::Skip | FailedHandling::SkipNoFiles => {
223 eprintln!("{}", err_msg);
224 continue;
225 }
226 FailedHandling::Keep => {
227 anyhow::bail!(err_msg);
228 }
229 }
230 }
231 };
232
233 let json = if args.pretty {
234 serde_json::to_string_pretty(&case)
235 } else {
236 serde_json::to_string(&case)
237 }
238 .context("failed to serialize evaluation case as JSON")?;
239
240 output_lines.push(json);
241 }
242 }
243 }
244
245 let output_content = output_lines.join("\n") + if output_lines.is_empty() { "" } else { "\n" };
246
247 if let Some(path) = output_path {
248 fs::write(path, &output_content)
249 .with_context(|| format!("failed to write output to {}", path.display()))?;
250 } else {
251 io::stdout()
252 .write_all(output_content.as_bytes())
253 .context("failed to write to stdout")?;
254 }
255
256 Ok(())
257}
258
/// Main function to generate an evaluation example from an ordered commit.
///
/// # Arguments
/// * `commit` - Chronologically ordered unified diff of the commit
/// * `repository_url` - URL of the repository
/// * `commit_hash` - Hash of the commit
/// * `split_point` - Point at which the commit will be split (None for random)
/// * `seed` - Optional seed for randomness
/// * `sample_num` - Optional sample number for generating unique names
///
/// # Errors
/// Fails when the commit contains no edits, or when a cursor position or
/// cursor excerpt cannot be derived from the split patches.
pub fn generate_evaluation_example_from_ordered_commit(
    commit: &str,
    repository_url: &str,
    commit_hash: &str,
    split_point: Option<SplitPoint>,
    seed: Option<u64>,
    sample_num: Option<usize>,
) -> Result<ExampleSpec> {
    // Seeded RNG for reproducibility; thread-local RNG otherwise.
    let mut rng: Box<dyn rand::RngCore> = match seed {
        Some(seed) => Box::new(rand::rngs::StdRng::seed_from_u64(seed)),
        None => Box::new(rand::rngs::ThreadRng::default()),
    };

    // Parse and normalize the commit
    let mut patch = Patch::parse_unified_diff(commit);

    // Filter header to only keep lines starting with "//"
    // (the "//" lines are the chronological group headers; everything else,
    // e.g. commit metadata, is dropped).
    let header_lines: Vec<&str> = patch
        .header
        .lines()
        .filter(|line| line.starts_with("//"))
        .collect();
    patch.header = if header_lines.is_empty() {
        String::new()
    } else {
        header_lines.join("\n") + "\n"
    };
    let commit_normalized = patch.to_string();

    // Compute the split point
    let stats = patch.stats();
    let num_edits = stats.added + stats.removed;

    anyhow::ensure!(num_edits != 0, "no edits found in commit");

    // A random split is always at least 1; explicit fractions/indices are
    // clamped to `num_edits` and may be 0.
    let split = match split_point {
        None => rng.random_range(1..=num_edits),
        Some(SplitPoint::Fraction(f)) => {
            let v = (f * num_edits as f64).floor() as usize;
            v.min(num_edits)
        }
        Some(SplitPoint::Index(i)) => i.min(num_edits),
    };

    // Split the commit into source and target patches
    let (prefix, suffix) = split_ordered_commit(&commit_normalized, split);

    let mut split_commit = SplitCommit {
        source_patch: prefix,
        target_patch: suffix,
    };

    // Imitate human edits
    let human_edit_seed = rng.random_range(1..=10000u64);
    let (src_patch, tgt_patch, cursor_opt) = imitate_human_edits(
        &split_commit.source_patch,
        &split_commit.target_patch,
        human_edit_seed,
    );
    split_commit.source_patch = src_patch;
    split_commit.target_patch = tgt_patch;

    // Sample cursor position
    // (imitate_human_edits only returns a cursor when it split a line.)
    let cursor = match cursor_opt {
        Some(c) => c,
        None => sample_cursor_position(&patch, &split_commit)
            .context("failed to sample cursor position")?,
    };

    // Get cursor excerpt
    let cursor_excerpt = get_cursor_excerpt(
        &cursor,
        &split_commit.source_patch,
        &split_commit.target_patch,
    )
    .context("failed to generate cursor excerpt")?;

    // Handle edge case where split_point == 0
    // NOTE(review): this runs AFTER the cursor and excerpt were computed from
    // the (non-empty) target patch, and a random split is never 0 — confirm
    // that clearing the *target* (not the source) here is intentional.
    if split == 0 {
        split_commit.target_patch = String::new();
    }

    // Derive a readable example name from the repo name and short SHA.
    let repo_name = repository_url
        .trim_end_matches('/')
        .rsplit('/')
        .next()
        .unwrap_or("unknown");
    // Byte slicing assumes an ASCII commit hash (true for hex SHAs).
    let short_sha = &commit_hash[..commit_hash.len().min(8)];
    let name = match sample_num {
        Some(n) => format!("{}-{}-{}", repo_name, short_sha, n),
        None => format!("{}-{}", repo_name, short_sha),
    };

    Ok(ExampleSpec {
        name,
        repository_url: repository_url.to_string(),
        // The example starts from the parent of the commit being split.
        revision: format!("{}~1", commit_hash),
        edit_history: split_commit.source_patch.clone(),
        cursor_path: Path::new(&cursor.file).into(),
        cursor_position: cursor_excerpt,
        expected_patches: vec![split_commit.target_patch],
        tags: vec![],
        reasoning: None,
        uncommitted_diff: String::new(),
        rejected_patch: None,
        captured_prompt_input: None,
        telemetry: None,
        human_feedback: Vec::new(),
        rating: None,
    })
}
379
380/// Split an ordered commit into source and target commits.
381///
382/// # Arguments
383/// * `commit` - Ordered commit string
384/// * `split_pos` - Position to split the commit (number of edited lines)
385///
386/// # Returns
387/// A tuple of (source_diff, target_diff)
388pub fn split_ordered_commit(commit: &str, split_pos: usize) -> (String, String) {
389 let patch = Patch::parse_unified_diff(commit);
390 let source_edits: BTreeSet<usize> = (0..split_pos).collect();
391 let (source, target) = extract_edits(&patch, &source_edits);
392
393 let mut source_str = source.to_string();
394 let target_str = target.to_string();
395
396 // Strip last group header from the source (lines starting with "//" at the end)
397 let source_lines: Vec<&str> = source_str.lines().collect();
398 let mut end_idx = source_lines.len();
399 for i in (0..source_lines.len()).rev() {
400 if source_lines[i].starts_with("//") {
401 end_idx = i;
402 } else {
403 break;
404 }
405 }
406 if end_idx < source_lines.len() {
407 source_str = source_lines[..end_idx].join("\n");
408 if !source_str.is_empty() {
409 source_str.push('\n');
410 }
411 }
412
413 (source_str, target_str)
414}
415
/// Tokenize text into words and non-word characters.
///
/// Runs of alphanumeric characters form word tokens. An underscore is kept
/// with the current word but terminates it (so `foo_bar` becomes
/// `["foo_", "bar"]`). Every other character is emitted as its own
/// single-character token.
fn tokenize(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut word = String::new();

    for ch in text.chars() {
        match ch {
            c if c.is_alphanumeric() => word.push(c),
            '_' => {
                // Underscore is attached to the word and ends it.
                word.push('_');
                tokens.push(std::mem::take(&mut word));
            }
            other => {
                // Separator: flush any pending word, then emit it on its own.
                if !word.is_empty() {
                    tokens.push(std::mem::take(&mut word));
                }
                tokens.push(other.to_string());
            }
        }
    }

    if !word.is_empty() {
        tokens.push(word);
    }

    tokens
}
446
/// Calculate the weight for a split position based on the character at that position.
///
/// `pos` is a *byte* offset into `text` (callers pass `text.len()` right after
/// pushing a char, so it is a char boundary in practice). Higher weights
/// indicate more natural pause points (e.g., after punctuation, at identifier
/// boundaries). Lower weights indicate less natural points (e.g.,
/// mid-identifier).
///
/// Bug fix: the previous implementation indexed a `Vec<char>` with the byte
/// offset, so any text containing multi-byte UTF-8 either examined the wrong
/// character or (when `pos > char count`) always returned 1. It now works in
/// byte offsets throughout.
fn position_weight(text: &str, pos: usize) -> u32 {
    if pos == 0 || pos > text.len() {
        return 1;
    }

    // The character just before this position (what we just "typed").
    // `get` returns None if `pos` is not a char boundary; treat that as neutral.
    let prev_char = match text.get(..pos).and_then(|prefix| prefix.chars().next_back()) {
        Some(c) => c,
        None => return 1,
    };

    // High weight: natural pause points (end of statement/argument, opening brackets)
    if matches!(prev_char, ',' | ';' | ':' | '(' | '[' | '{') {
        return 10;
    }

    // High weight: closing brackets (finished a group)
    if matches!(prev_char, ')' | ']' | '}') {
        return 8;
    }

    // Medium weight: operators and method chains
    if matches!(
        prev_char,
        '.' | '+' | '-' | '*' | '/' | '=' | '<' | '>' | '&' | '|' | '!'
    ) {
        return 5;
    }

    // Check if we're at the end of an identifier (word char followed by non-word char)
    let is_prev_word_char = prev_char.is_alphanumeric() || prev_char == '_';
    let is_next_word_char = text[pos..]
        .chars()
        .next()
        .map_or(false, |c| c.is_alphanumeric() || c == '_');

    if is_prev_word_char && !is_next_word_char {
        // End of identifier - high weight
        return 8;
    }

    // Whitespace is a natural pause
    if prev_char.is_whitespace() {
        return 6;
    }

    // Mid-identifier: low weight (rare autocomplete scenarios)
    if is_prev_word_char && is_next_word_char {
        return 1;
    }

    // Default medium-low weight
    3
}
506
/// Select a weighted random index from a list of weights.
///
/// Returns an index based on the weights, using the provided seed for
/// deterministic selection.
fn weighted_select(weights: &[u32], seed: u64) -> usize {
    if weights.is_empty() {
        return 0;
    }

    let total: u64 = weights.iter().map(|&w| u64::from(w)).sum();
    if total == 0 {
        // Degenerate case: all weights zero — pick uniformly from the seed.
        return seed as usize % weights.len();
    }

    // Map the seed onto [0, total) and find which bucket it lands in by
    // subtracting weights until it no longer fits.
    let mut remaining = seed % total;
    for (idx, &weight) in weights.iter().enumerate() {
        let w = u64::from(weight);
        if remaining < w {
            return idx;
        }
        remaining -= w;
    }

    // Unreachable in practice because `remaining < total`; keep a safe fallback.
    weights.len() - 1
}
536
537/// Calculate similarity ratio between two strings (0-100).
538fn fuzzy_ratio(s1: &str, s2: &str) -> u32 {
539 if s1.is_empty() && s2.is_empty() {
540 return 100;
541 }
542 if s1.is_empty() || s2.is_empty() {
543 return 0;
544 }
545
546 let diff = TextDiff::from_chars(s1, s2);
547 let matching: usize = diff
548 .ops()
549 .iter()
550 .filter_map(|op| {
551 if matches!(op.tag(), DiffTag::Equal) {
552 Some(op.new_range().len())
553 } else {
554 None
555 }
556 })
557 .sum();
558
559 let total = s1.len() + s2.len();
560 ((2 * matching * 100) / total) as u32
561}
562
563/// Imitate human edits by introducing partial line edits.
564///
565/// This function simulates how a human might incrementally type code,
566/// rather than making complete line replacements.
567pub fn imitate_human_edits(
568 source_patch: &str,
569 target_patch: &str,
570 seed: u64,
571) -> (String, String, Option<CursorPosition>) {
572 let no_change = (source_patch.to_string(), target_patch.to_string(), None);
573
574 let src_patch = Patch::parse_unified_diff(source_patch);
575 let tgt_patch = Patch::parse_unified_diff(target_patch);
576
577 if tgt_patch.hunks.is_empty() {
578 return no_change;
579 }
580
581 // Try to locate the first edit in target
582 let tgt_edit_loc = match locate_edited_line(&tgt_patch, 0) {
583 Some(loc) => loc,
584 None => return no_change,
585 };
586
587 let tgt_is_addition = matches!(tgt_edit_loc.patch_line, PatchLine::Addition(_));
588 if !tgt_is_addition {
589 return no_change;
590 }
591
592 let tgt_line = match &tgt_edit_loc.patch_line {
593 PatchLine::Addition(s) => s.clone(),
594 _ => return no_change,
595 };
596
597 // Try to locate the last edit in source
598 let src_edit_loc = locate_edited_line(&src_patch, -1);
599
600 // Check if source has ANY edit at the same line as target's first edit
601 // We need to iterate through all edits to check this
602 let src_has_edit_at_target_line = {
603 let mut found = false;
604 let mut idx = 0isize;
605 while let Some(loc) = locate_edited_line(&src_patch, idx) {
606 if loc.filename == tgt_edit_loc.filename
607 && loc.target_line_number == tgt_edit_loc.target_line_number
608 {
609 found = true;
610 break;
611 }
612 idx += 1;
613 }
614 found
615 };
616
617 // Check if this is a replacement (deletion followed by insertion on the same line)
618 // or a pure insertion (no corresponding deletion in source)
619 let is_replacement = src_edit_loc.as_ref().map_or(false, |loc| {
620 matches!(loc.patch_line, PatchLine::Deletion(_))
621 && loc.filename == tgt_edit_loc.filename
622 && loc.target_line_number == tgt_edit_loc.target_line_number
623 });
624
625 // If source has an edit at the same line but it's not a replacement (i.e., it's an addition),
626 // we shouldn't process this as a pure insertion either
627 if !is_replacement && src_has_edit_at_target_line {
628 return no_change;
629 }
630
631 let src_line = if is_replacement {
632 match &src_edit_loc.as_ref().unwrap().patch_line {
633 PatchLine::Deletion(s) => s.clone(),
634 _ => return no_change,
635 }
636 } else {
637 // Pure insertion: source line is empty
638 String::new()
639 };
640
641 // Don't process if source and target are the same
642 if src_line == tgt_line {
643 return no_change;
644 }
645
646 // Tokenize both lines
647 let src_tokens = tokenize(&src_line);
648 let tgt_tokens = tokenize(&tgt_line);
649
650 // Convert to slices for similar
651 let src_refs: Vec<&str> = src_tokens.iter().map(|s| s.as_str()).collect();
652 let tgt_refs: Vec<&str> = tgt_tokens.iter().map(|s| s.as_str()).collect();
653
654 // Use similar to get diff operations
655 let diff = TextDiff::from_slices(&src_refs, &tgt_refs);
656
657 // Build weights for each possible split position
658 let mut position_weights: Vec<u32> = Vec::new();
659
660 // Simulate the edit process to collect weights for all possible split positions
661 {
662 let mut current_text = String::new();
663
664 for op in diff.ops() {
665 match op.tag() {
666 DiffTag::Equal => {
667 for i in op.old_range() {
668 current_text.push_str(&src_tokens[i]);
669 }
670 }
671 DiffTag::Replace => {
672 let ins: String = op.new_range().map(|i| tgt_tokens[i].as_str()).collect();
673 let del: String = op.old_range().map(|i| src_tokens[i].as_str()).collect();
674
675 // For insertion part
676 for ch in ins.chars() {
677 current_text.push(ch);
678 let weight = position_weight(¤t_text, current_text.len());
679 position_weights.push(weight);
680 }
681
682 // For deletion part (we're "untyping" from source)
683 for _ in del.chars() {
684 // Weight deletions lower as they represent removing text
685 position_weights.push(2);
686 }
687 }
688 DiffTag::Insert => {
689 let ins: String = op.new_range().map(|i| tgt_tokens[i].as_str()).collect();
690 for ch in ins.chars() {
691 current_text.push(ch);
692 let weight = position_weight(¤t_text, current_text.len());
693 position_weights.push(weight);
694 }
695 }
696 DiffTag::Delete => {
697 let del: String = op.old_range().map(|i| src_tokens[i].as_str()).collect();
698 for _ in del.chars() {
699 // Weight deletions lower
700 position_weights.push(2);
701 }
702 }
703 }
704 }
705 }
706
707 // Use weighted selection to choose split index
708 if position_weights.is_empty() {
709 return no_change;
710 }
711 let split_index = weighted_select(&position_weights, seed);
712
713 let mut edit_index = 0usize;
714 let mut new_src = String::new();
715 let mut split_found = false;
716 let mut last_old_end = 0usize;
717
718 for op in diff.ops() {
719 match op.tag() {
720 DiffTag::Equal => {
721 for i in op.old_range() {
722 new_src.push_str(&src_tokens[i]);
723 }
724 last_old_end = op.old_range().end;
725 }
726 DiffTag::Replace => {
727 // Handle replace as delete + insert
728 let del: String = op.old_range().map(|i| src_tokens[i].as_str()).collect();
729 let ins: String = op.new_range().map(|i| tgt_tokens[i].as_str()).collect();
730 let repl_len = del.len() + ins.len();
731 if edit_index + repl_len >= split_index {
732 // Split within this replace operation
733 let offset = split_index - edit_index;
734 if offset < ins.len() {
735 let safe_offset = floor_char_boundary(&ins, offset);
736 new_src.push_str(&ins[..safe_offset]);
737 } else {
738 new_src.push_str(&ins);
739 let del_offset = offset - ins.len();
740 let safe_del_offset = floor_char_boundary(&del, del_offset.min(del.len()));
741 new_src.push_str(&del[..safe_del_offset]);
742 }
743 split_found = true;
744 last_old_end = op.old_range().end;
745 break;
746 } else {
747 edit_index += repl_len;
748 new_src.push_str(&ins);
749 last_old_end = op.old_range().end;
750 }
751 }
752 DiffTag::Insert => {
753 let repl: String = op.new_range().map(|i| tgt_tokens[i].as_str()).collect();
754 if edit_index + repl.len() >= split_index {
755 let offset = split_index - edit_index;
756 let safe_offset = floor_char_boundary(&repl, offset);
757 new_src.push_str(&repl[..safe_offset]);
758 split_found = true;
759 break;
760 } else {
761 edit_index += repl.len();
762 new_src.push_str(&repl);
763 }
764 }
765 DiffTag::Delete => {
766 let repl: String = op.old_range().map(|i| src_tokens[i].as_str()).collect();
767 if edit_index + repl.len() >= split_index {
768 let offset = split_index - edit_index;
769 let safe_offset = floor_char_boundary(&repl, offset);
770 new_src.push_str(&repl[..safe_offset]);
771 split_found = true;
772 last_old_end = op.old_range().start + safe_offset.min(op.old_range().len());
773 break;
774 } else {
775 edit_index += repl.len();
776 new_src.push_str(&repl);
777 last_old_end = op.old_range().end;
778 }
779 }
780 }
781 }
782
783 if !split_found {
784 return no_change;
785 }
786
787 // Calculate cursor position
788 let cursor = CursorPosition {
789 file: tgt_edit_loc.filename.clone(),
790 line: if is_replacement {
791 src_edit_loc.as_ref().unwrap().source_line_number
792 } else {
793 tgt_edit_loc.target_line_number
794 },
795 column: new_src.len() + 1,
796 };
797
798 // Add remainder of source if similar enough to target remainder
799 let remainder_src: String = (last_old_end..src_tokens.len())
800 .map(|i| src_tokens[i].as_str())
801 .collect();
802 let remainder_tgt: String = (last_old_end..tgt_tokens.len())
803 .filter_map(|i| tgt_tokens.get(i).map(|s| s.as_str()))
804 .collect();
805
806 let ratio = fuzzy_ratio(&remainder_src, &remainder_tgt);
807 if ratio > 35 {
808 new_src.push_str(&remainder_src);
809 }
810
811 if new_src.trim().is_empty() {
812 return no_change;
813 }
814
815 if new_src == src_line {
816 return no_change;
817 }
818
819 // Build new source patch with the intermediate line
820 let mut new_src_patch = src_patch;
821 if is_replacement {
822 // For replacements, insert after the deletion line
823 let src_loc = src_edit_loc.as_ref().unwrap();
824 if let Some(hunk) = new_src_patch.hunks.get_mut(src_loc.hunk_index) {
825 hunk.lines.insert(
826 src_loc.line_index_within_hunk + 1,
827 PatchLine::Addition(new_src.clone()),
828 );
829 hunk.new_count += 1;
830 }
831 } else {
832 // For pure insertions, insert after the last edit in source patch
833 // This imitates human typing - the intermediate content is what the user is currently typing
834 let last_src_edit = locate_edited_line(&new_src_patch, -1);
835
836 if let Some(src_loc) = last_src_edit {
837 // Insert after the last edit in source
838 if let Some(hunk) = new_src_patch.hunks.get_mut(src_loc.hunk_index) {
839 hunk.lines.insert(
840 src_loc.line_index_within_hunk + 1,
841 PatchLine::Addition(new_src.clone()),
842 );
843 hunk.new_count += 1;
844 }
845 } else {
846 // Source patch is empty or has incompatible hunk structure, create a new hunk based on target
847 if let Some(tgt_hunk) = tgt_patch.hunks.get(tgt_edit_loc.hunk_index) {
848 let mut new_hunk = tgt_hunk.clone();
849 // Replace the full addition with the partial one
850 new_hunk.lines.clear();
851 for (i, line) in tgt_hunk.lines.iter().enumerate() {
852 if i == tgt_edit_loc.line_index_within_hunk {
853 new_hunk.lines.push(PatchLine::Addition(new_src.clone()));
854 } else {
855 match line {
856 PatchLine::Addition(_) => {
857 // Skip other additions from target
858 }
859 _ => new_hunk.lines.push(line.clone()),
860 }
861 }
862 }
863 new_hunk.new_count = new_hunk.old_count + 1;
864 new_src_patch.hunks.push(new_hunk);
865 // Copy header from target if source doesn't have one
866 if new_src_patch.header.is_empty() {
867 new_src_patch.header = tgt_patch.header.clone();
868 }
869 }
870 }
871 }
872
873 // Build new target patch with the intermediate line as deletion
874 let mut new_tgt_patch = tgt_patch;
875 if let Some(hunk) = new_tgt_patch.hunks.get_mut(tgt_edit_loc.hunk_index) {
876 hunk.lines.insert(
877 tgt_edit_loc.line_index_within_hunk,
878 PatchLine::Deletion(new_src),
879 );
880 hunk.old_count += 1;
881 }
882
883 (
884 new_src_patch.to_string(),
885 new_tgt_patch.to_string(),
886 Some(cursor),
887 )
888}
889
890/// Locate the end of the last edit in a patch.
891fn locate_end_of_last_edit(patch: &Patch) -> Option<CursorPosition> {
892 let loc = locate_edited_line(patch, -1)?;
893
894 let (line, col) = match &loc.patch_line {
895 PatchLine::Addition(content) => (loc.target_line_number, content.len()),
896 PatchLine::Deletion(_) => (loc.target_line_number, 1),
897 _ => return None,
898 };
899
900 Some(CursorPosition {
901 file: loc.filename,
902 line,
903 column: col,
904 })
905}
906
907/// Locate the beginning of the first edit in a patch.
908fn locate_beginning_of_first_edit(patch: &Patch) -> Option<CursorPosition> {
909 let loc = locate_edited_line(patch, 0)?;
910
911 let hunk = patch.hunks.get(loc.hunk_index)?;
912 let column = if loc.line_index_within_hunk > 0 {
913 if let Some(prev_line) = hunk.lines.get(loc.line_index_within_hunk - 1) {
914 let content = match prev_line {
915 PatchLine::Context(s) | PatchLine::Addition(s) | PatchLine::Deletion(s) => s,
916 _ => return None,
917 };
918 content.len().max(1) - 1
919 } else {
920 0
921 }
922 } else {
923 0
924 };
925
926 let line = loc.target_line_number.saturating_sub(1).max(1);
927
928 Some(CursorPosition {
929 file: loc.filename,
930 line,
931 column,
932 })
933}
934
/// Choose a cursor position for the example.
///
/// NOTE(review): a previous comment described a "50% chance" of each
/// placement, but the implementation is deterministic. It tries, in order:
/// 1. the end of the last edit in the source patch,
/// 2. the beginning of the first edit in the target patch,
/// 3. the end of the last edit in the original `patch` as a fallback.
pub fn sample_cursor_position(patch: &Patch, split_commit: &SplitCommit) -> Option<CursorPosition> {
    // Try end of history first
    let src_patch = Patch::parse_unified_diff(&split_commit.source_patch);
    if let Some(cursor) = locate_end_of_last_edit(&src_patch) {
        return Some(cursor);
    }

    // Try beginning of target
    let tgt_patch = Patch::parse_unified_diff(&split_commit.target_patch);
    if let Some(cursor) = locate_beginning_of_first_edit(&tgt_patch) {
        return Some(cursor);
    }

    // Fallback: use the original patch
    locate_end_of_last_edit(patch)
}
954
955/// Get cursor excerpt from the patches.
956///
957/// This extracts the lines around the cursor position with a cursor marker.
958pub fn get_cursor_excerpt(
959 cursor: &CursorPosition,
960 source_patch: &str,
961 target_patch: &str,
962) -> Option<String> {
963 let mut excerpt_lines: Vec<String> = Vec::new();
964 let mut excerpt_first_line: usize = 0;
965
966 // Search in the last hunk of source patch
967 let src = Patch::parse_unified_diff(source_patch);
968 if let Some(loc) = locate_edited_line(&src, -1) {
969 if loc.filename == cursor.file && loc.target_line_number == cursor.line {
970 if let Some(hunk) = src.hunks.get(loc.hunk_index) {
971 excerpt_first_line = hunk.new_start as usize;
972 for line in &hunk.lines {
973 match line {
974 PatchLine::Addition(s) | PatchLine::Context(s) => {
975 excerpt_lines.push(s.clone());
976 }
977 _ => {}
978 }
979 }
980 // If hunk only has deletions (file deletion), include deletion lines
981 if excerpt_lines.is_empty() {
982 excerpt_first_line = hunk.old_start as usize;
983 for line in &hunk.lines {
984 match line {
985 PatchLine::Deletion(s) => {
986 excerpt_lines.push(s.clone());
987 }
988 _ => {}
989 }
990 }
991 }
992 }
993 }
994 }
995
996 // Search in target patch if not found
997 if excerpt_lines.is_empty() {
998 let tgt = Patch::parse_unified_diff(target_patch);
999 // Search all hunks for the cursor file, not just the first edit's hunk
1000 for hunk in &tgt.hunks {
1001 if hunk.filename == cursor.file {
1002 excerpt_first_line = hunk.new_start as usize;
1003 // First try to collect deletions and context (what exists before edits)
1004 for line in &hunk.lines {
1005 match line {
1006 PatchLine::Deletion(s) | PatchLine::Context(s) => {
1007 excerpt_lines.push(s.clone());
1008 }
1009 _ => {}
1010 }
1011 }
1012 // If hunk only has additions (no deletions/context), include all lines
1013 // This handles cases like adding to an empty file or section
1014 if excerpt_lines.is_empty() {
1015 for line in &hunk.lines {
1016 match line {
1017 PatchLine::Addition(s)
1018 | PatchLine::Deletion(s)
1019 | PatchLine::Context(s) => {
1020 excerpt_lines.push(s.clone());
1021 }
1022 _ => {}
1023 }
1024 }
1025 }
1026 if !excerpt_lines.is_empty() {
1027 break;
1028 }
1029 }
1030 }
1031 }
1032
1033 // Also search source patch hunks if still not found (for fallback cursor case)
1034 if excerpt_lines.is_empty() {
1035 for hunk in &src.hunks {
1036 if hunk.filename == cursor.file {
1037 excerpt_first_line = hunk.new_start as usize;
1038 for line in &hunk.lines {
1039 match line {
1040 PatchLine::Addition(s) | PatchLine::Context(s) => {
1041 excerpt_lines.push(s.clone());
1042 }
1043 _ => {}
1044 }
1045 }
1046 // If hunk only has deletions, include deletion lines
1047 if excerpt_lines.is_empty() {
1048 excerpt_first_line = hunk.old_start as usize;
1049 for line in &hunk.lines {
1050 match line {
1051 PatchLine::Deletion(s) => {
1052 excerpt_lines.push(s.clone());
1053 }
1054 _ => {}
1055 }
1056 }
1057 }
1058 if !excerpt_lines.is_empty() {
1059 break;
1060 }
1061 }
1062 }
1063 }
1064
1065 if excerpt_lines.is_empty() {
1066 return None;
1067 }
1068
1069 // Add cursor marker
1070 for (i, line) in excerpt_lines.iter_mut().enumerate() {
1071 let line_num = excerpt_first_line + i;
1072 if line_num == cursor.line {
1073 let col = cursor.column.min(line.len());
1074 // Ensure we split at a valid UTF-8 character boundary
1075 let col = if line.is_char_boundary(col) {
1076 col
1077 } else {
1078 // Find the nearest valid character boundary
1079 (0..=col)
1080 .rev()
1081 .find(|&i| line.is_char_boundary(i))
1082 .unwrap_or(0)
1083 };
1084 let (before, after) = line.split_at(col);
1085 *line = format!("{}<|user_cursor|>{}", before, after);
1086 break;
1087 }
1088 }
1089
1090 Some(excerpt_lines.join("\n"))
1091}
1092
#[cfg(test)]
mod tests {
    use std::path::Path;

    use edit_prediction::example_spec::ExampleSpec;

    use super::*;

    // `tokenize` splits text into typing-sized units: identifier chunks
    // (an underscore terminates a chunk and stays attached to it), runs of
    // whitespace, and individual punctuation characters.
    #[test]
    fn test_tokenize() {
        let tokens = tokenize("hello world");
        assert_eq!(tokens, vec!["hello", " ", "world"]);

        let tokens = tokenize("foo_bar123 + baz");
        assert_eq!(tokens, vec!["foo_", "bar123", " ", "+", " ", "baz"]);

        let tokens = tokenize("print(\"hello\")");
        assert_eq!(tokens, vec!["print", "(", "\"", "hello", "\"", ")"]);

        let tokens = tokenize("hello_world");
        assert_eq!(tokens, vec!["hello_", "world"]);

        let tokens = tokenize("fn();");
        assert_eq!(tokens, vec!["fn", "(", ")", ";"]);
    }

    // `fuzzy_ratio` returns a 0-100 similarity score: 100 for identical
    // inputs (including two empty strings), lower as the strings diverge.
    #[test]
    fn test_fuzzy_ratio() {
        assert_eq!(fuzzy_ratio("hello", "hello"), 100);
        assert_eq!(fuzzy_ratio("", ""), 100);
        assert!(fuzzy_ratio("hello", "world") < 50);
        assert!(fuzzy_ratio("hello world", "hello worl") > 80);
    }

    // Splitting a commit with two additions at index 1 should put exactly
    // one addition in the source patch and one in the target patch.
    #[test]
    fn test_split_ordered_commit() {
        let commit = r#"// First change
--- a/test.rs
+++ b/test.rs
@@ -1,3 +1,4 @@
 fn main() {
+    println!("hello");
+    println!("world");
 }
"#;
        let patch = Patch::parse_unified_diff(commit);
        let stats = patch.stats();
        assert_eq!(stats.added, 2);

        let (source, target) = split_ordered_commit(commit, 1);

        // Source should have 1 addition
        let src_patch = Patch::parse_unified_diff(&source);
        assert_eq!(src_patch.stats().added, 1);

        // Target should have 1 addition
        let tgt_patch = Patch::parse_unified_diff(&target);
        assert_eq!(tgt_patch.stats().added, 1);
    }

    // A replacement (one deletion + one addition) split between the two
    // edits: the deletion lands in the source, the addition in the target.
    #[test]
    fn test_split_ordered_commit_with_deletions() {
        let commit = r#"// Change
--- a/test.rs
+++ b/test.rs
@@ -1,3 +1,3 @@
 fn main() {
-    println!("old");
+    println!("new");
 }
"#;
        let patch = Patch::parse_unified_diff(commit);
        let stats = patch.stats();
        assert_eq!(stats.added, 1);
        assert_eq!(stats.removed, 1);

        // Split at position 1 (after the deletion)
        let (source, target) = split_ordered_commit(commit, 1);

        let src_patch = Patch::parse_unified_diff(&source);
        let tgt_patch = Patch::parse_unified_diff(&target);

        // Source should have the deletion
        assert_eq!(src_patch.stats().removed, 1);
        // Target should have the addition
        assert_eq!(tgt_patch.stats().added, 1);
    }

    // End-to-end example generation: the resulting case points at the
    // parent revision ("<sha>~1") and carries a non-empty edit history.
    #[test]
    fn test_generate_evaluation_example() {
        let commit = r#"commit abc123
Author: Test <test@example.com>
Date: Mon Jan 1 00:00:00 2024

    Test commit

////////////////////////////////////////////////////////////////////////////////
// Add greeting
////////////////////////////////////////////////////////////////////////////////
--- a/test.rs
+++ b/test.rs
@@ -1,3 +1,5 @@
 fn main() {
+    println!("hello");
+    println!("world");
 }
"#;

        let result = generate_evaluation_example_from_ordered_commit(
            commit,
            "https://github.com/test/repo",
            "abc123",
            Some(SplitPoint::Fraction(0.5)),
            Some(42),
            None,
        );

        assert!(result.is_ok());
        let case = result.unwrap();
        assert_eq!(case.repository_url, "https://github.com/test/repo");
        assert_eq!(case.revision, "abc123~1");
        assert!(!case.edit_history.is_empty());
    }

    // Same commit + same seed must yield byte-identical output: edit
    // history, expected patches, and cursor position are all deterministic.
    #[test]
    fn test_generate_evaluation_example_reproducible() {
        let commit = r#"////////////////////////////////////////////////////////////////////////////////
// Add greeting
////////////////////////////////////////////////////////////////////////////////
--- a/test.rs
+++ b/test.rs
@@ -1,3 +1,5 @@
 fn main() {
+    println!("hello");
+    println!("world");
 }
"#;

        // Run twice with the same seed
        let result1 = generate_evaluation_example_from_ordered_commit(
            commit,
            "https://github.com/test/repo",
            "abc123",
            Some(SplitPoint::Fraction(0.5)),
            Some(12345),
            None,
        )
        .unwrap();

        let result2 = generate_evaluation_example_from_ordered_commit(
            commit,
            "https://github.com/test/repo",
            "abc123",
            Some(SplitPoint::Fraction(0.5)),
            Some(12345),
            None,
        )
        .unwrap();

        // Results should be identical
        assert_eq!(result1.edit_history, result2.edit_history);
        assert_eq!(result1.expected_patches, result2.expected_patches);
        assert_eq!(result1.cursor_position, result2.cursor_position);
    }

    // `CursorPosition` formats as "file:line:column" via `Display`.
    #[test]
    fn test_cursor_position_display() {
        let cursor = CursorPosition {
            file: "src/main.rs".to_string(),
            line: 42,
            column: 10,
        };
        assert_eq!(cursor.to_string(), "src/main.rs:42:10");
    }

    #[test]
    fn test_imitate_human_edits_no_change_when_no_replacement() {
        // Source and target patches that don't form a replacement pattern
        let source = r#"--- a/test.rs
+++ b/test.rs
@@ -1,3 +1,4 @@
 fn main() {
+    println!("hello");
 }
"#;
        let target = r#"--- a/test.rs
+++ b/test.rs
@@ -1,3 +1,4 @@
 fn main() {
+    println!("world");
 }
"#;

        let (new_src, new_tgt, cursor) = imitate_human_edits(source, target, 42);

        // Should return unchanged when not a replacement pattern
        assert_eq!(new_src, source);
        assert_eq!(new_tgt, target);
        assert!(cursor.is_none());
    }

    // Fraction split points select a proportional prefix of the edits.
    #[test]
    fn test_split_point_fraction() {
        let commit = r#"// Change
--- a/test.rs
+++ b/test.rs
@@ -1,5 +1,10 @@
 fn main() {
+    line1();
+    line2();
+    line3();
+    line4();
+    line5();
 }
"#;

        // Split at 20% should give first edit in source
        let result = generate_evaluation_example_from_ordered_commit(
            commit,
            "",
            "hash",
            Some(SplitPoint::Fraction(0.2)),
            Some(1),
            None,
        );

        assert!(result.is_ok());
        let case = result.unwrap();

        // Source should have some edits
        let src_patch = Patch::parse_unified_diff(&case.edit_history);
        assert!(src_patch.stats().added > 0);
    }

    // Index split points select an exact number of leading edits.
    #[test]
    fn test_split_point_index() {
        let commit = r#"// Change
--- a/test.rs
+++ b/test.rs
@@ -1,5 +1,10 @@
 fn main() {
+    line1();
+    line2();
+    line3();
+    line4();
+    line5();
 }
"#;

        // Split at index 2 should give first 2 edits in source
        // With pure insertion handling, source gets 2 original + 1 partial = 3 additions
        let result = generate_evaluation_example_from_ordered_commit(
            commit,
            "",
            "hash",
            Some(SplitPoint::Index(2)),
            Some(1),
            None,
        );

        assert!(result.is_ok());
        let case = result.unwrap();

        let src_patch = Patch::parse_unified_diff(&case.edit_history);
        // Pure insertion adds a partial line, so we expect 3 (2 original + 1 partial)
        assert_eq!(src_patch.stats().added, 3);
    }

    // The generated cursor excerpt must embed the `<|user_cursor|>` marker.
    #[test]
    fn test_cursor_excerpt_contains_marker() {
        let commit = r#"////////////////////////////////////////////////////////////////////////////////
// Add code
////////////////////////////////////////////////////////////////////////////////
--- a/test.rs
+++ b/test.rs
@@ -1,3 +1,5 @@
 fn main() {
+    println!("hello");
+    println!("world");
 }
"#;

        let result = generate_evaluation_example_from_ordered_commit(
            commit,
            "",
            "hash",
            Some(SplitPoint::Fraction(0.5)),
            Some(42),
            None,
        )
        .unwrap();

        // Cursor excerpt should contain the cursor marker
        assert!(
            result.cursor_position.contains("<|user_cursor|>"),
            "Cursor excerpt should contain marker: {}",
            result.cursor_position
        );
    }

    // `ExampleSpec` must survive a serde JSON round-trip with the fields
    // this command populates intact.
    #[test]
    fn test_evaluation_case_json_serialization() {
        let case = ExampleSpec {
            name: "test-abc123".to_string(),
            repository_url: "https://github.com/test/repo".to_string(),
            revision: "abc123~1".to_string(),
            edit_history: "patch1".to_string(),
            cursor_path: Path::new("file.rs").into(),
            cursor_position: "some code<|user_cursor|>".to_string(),
            expected_patches: vec!["patch".to_string()],
            tags: vec![],
            reasoning: None,
            uncommitted_diff: String::new(),
            rejected_patch: None,
            captured_prompt_input: None,
            telemetry: None,
            human_feedback: Vec::new(),
            rating: None,
        };

        let json = serde_json::to_string(&case).unwrap();
        let deserialized: ExampleSpec = serde_json::from_str(&json).unwrap();

        assert_eq!(case.repository_url, deserialized.repository_url);
        assert_eq!(case.revision, deserialized.revision);
        assert_eq!(case.cursor_position, deserialized.cursor_position);
    }

    // An empty commit has no edits to split, so generation must fail.
    #[test]
    fn test_empty_commit_returns_error() {
        let commit = "";

        let result = generate_evaluation_example_from_ordered_commit(
            commit,
            "",
            "hash",
            Some(SplitPoint::Fraction(0.5)),
            Some(1),
            None,
        );

        assert!(result.is_err());
    }

    // Commit metadata (Author/Date lines) must be stripped from the edit
    // history, while the group-header comment lines are allowed to remain.
    #[test]
    fn test_header_filtering() {
        let commit = r#"commit abc123
Author: Test
Date: Today

    Message

diff --git a/test.rs b/test.rs
index 123..456 789
////////////////////////////////////////////////////////////////////////////////
// First group
////////////////////////////////////////////////////////////////////////////////
--- a/test.rs
+++ b/test.rs
@@ -1,3 +1,4 @@
 fn main() {
+    code();
 }
"#;

        let result = generate_evaluation_example_from_ordered_commit(
            commit,
            "",
            "hash",
            Some(SplitPoint::Index(1)),
            Some(1),
            None,
        );

        assert!(result.is_ok());
        let case = result.unwrap();

        // The edit history should contain the group header (// lines)
        // but not the commit metadata
        assert!(!case.edit_history.contains("Author:"));
        assert!(!case.edit_history.contains("Date:"));
    }

    // Pins the exact weight table: 10 after opening punctuation/separators,
    // 8 after closing brackets or at the end of an identifier, 5-6 after
    // operators/whitespace, 1 everywhere else (including mid-identifier).
    #[test]
    fn test_position_weight() {
        // High weight positions (natural pause points)
        assert_eq!(position_weight("foo(", 4), 10); // After '('
        assert_eq!(position_weight("a, b", 2), 10); // After ','
        assert_eq!(position_weight("x;", 2), 10); // After ';'
        assert_eq!(position_weight("a: b", 2), 10); // After ':'
        assert_eq!(position_weight("[", 1), 10); // After '['
        assert_eq!(position_weight("{", 1), 10); // After '{'

        // High weight for closing brackets
        assert_eq!(position_weight("foo)", 4), 8); // After ')'
        assert_eq!(position_weight("]", 1), 8); // After ']'
        assert_eq!(position_weight("}", 1), 8); // After '}'

        // High weight at end of identifier
        assert_eq!(position_weight("foo ", 3), 8); // End of 'foo' before space
        assert_eq!(position_weight("bar(", 3), 8); // End of 'bar' before '('

        // Medium weight for operators
        assert_eq!(position_weight("a + b", 3), 5); // After '+'
        assert_eq!(position_weight("x.", 2), 5); // After '.'
        assert_eq!(position_weight("a=b", 2), 5); // After '='

        // Medium weight for whitespace
        assert_eq!(position_weight("a ", 2), 6); // After space

        // Low weight mid-identifier
        assert_eq!(position_weight("foobar", 3), 1); // Mid-identifier 'foo|bar'

        // Edge cases
        assert_eq!(position_weight("", 0), 1); // Empty string
        assert_eq!(position_weight("a", 0), 1); // Position 0
    }

    // `weighted_select(weights, seed)` walks the cumulative distribution of
    // `seed % total_weight`; these cases pin each boundary of that walk.
    #[test]
    fn test_weighted_select() {
        // Test that weighted selection returns correct indices
        let weights = vec![1, 10, 1];

        // With total weight 12, seed 0 should select index 0
        // seed 0 % 12 = 0, cumulative: 1 at idx 0, so returns 0
        assert_eq!(weighted_select(&weights, 0), 0);

        // seed 1 % 12 = 1, cumulative: 1 at idx 0 (1 < 1 is false), 11 at idx 1 (1 < 11 is true)
        assert_eq!(weighted_select(&weights, 1), 1);

        // seed 10 % 12 = 10, cumulative: 1, 11 at idx 1 (10 < 11 is true)
        assert_eq!(weighted_select(&weights, 10), 1);

        // seed 11 % 12 = 11, cumulative: 1, 11 at idx 1 (11 < 11 is false), 12 at idx 2 (11 < 12 is true)
        assert_eq!(weighted_select(&weights, 11), 2);

        // Empty weights should return 0
        let empty: Vec<u32> = vec![];
        assert_eq!(weighted_select(&empty, 42), 0);

        // Single weight should always return index 0
        let single = vec![10];
        assert_eq!(weighted_select(&single, 0), 0);
        assert_eq!(weighted_select(&single, 100), 0);
    }

    #[test]
    fn test_weighted_split_prefers_natural_boundaries() {
        // Test that with different seeds, weighted selection tends to prefer
        // positions after punctuation over mid-identifier positions
        let text_with_punctuation = "foo(bar, baz)";
        let text_mid_identifier = "foobar";

        // Position after '(' should have high weight
        let weight_after_paren = position_weight(text_with_punctuation, 4);
        // Position after ',' should have high weight
        let weight_after_comma = position_weight(text_with_punctuation, 8);
        // Position mid-identifier should have low weight
        let weight_mid_ident = position_weight(text_mid_identifier, 3);

        assert!(
            weight_after_paren > weight_mid_ident,
            "After '(' ({}) should be weighted higher than mid-identifier ({})",
            weight_after_paren,
            weight_mid_ident
        );
        assert!(
            weight_after_comma > weight_mid_ident,
            "After ',' ({}) should be weighted higher than mid-identifier ({})",
            weight_after_comma,
            weight_mid_ident
        );
    }

    #[test]
    fn test_imitate_human_edits_pure_insertion() {
        // Source patch is empty (no edits yet)
        // Target patch has a pure insertion (adding a new line)
        let source = r#"--- a/test.rs
+++ b/test.rs
@@ -1,2 +1,2 @@
 fn main() {
 }
"#;
        let target = r#"--- a/test.rs
+++ b/test.rs
@@ -1,2 +1,3 @@
 fn main() {
+    println!("debug");
 }
"#;

        let (new_src, new_tgt, cursor) = imitate_human_edits(source, target, 42);

        // Should have transformed the patches
        assert_ne!(
            new_src, source,
            "Source should be modified for pure insertion"
        );
        assert_ne!(
            new_tgt, target,
            "Target should be modified for pure insertion"
        );
        assert!(cursor.is_some(), "Cursor should be set");

        // Source should now have a partial addition
        let src_patch = Patch::parse_unified_diff(&new_src);
        assert!(
            src_patch.stats().added > 0,
            "Source should have added lines"
        );

        // Target should have both a deletion (of partial) and addition (of full)
        let tgt_patch = Patch::parse_unified_diff(&new_tgt);
        assert!(
            tgt_patch.stats().removed > 0,
            "Target should have removed lines (partial)"
        );
        assert!(
            tgt_patch.stats().added > 0,
            "Target should have added lines (full)"
        );

        // The cursor should be in test.rs
        let cursor = cursor.unwrap();
        assert_eq!(cursor.file, "test.rs");
    }

    #[test]
    fn test_imitate_human_edits_pure_insertion_empty_source() {
        // Source patch has no hunks at all
        let source = "";
        let target = r#"--- a/test.rs
+++ b/test.rs
@@ -1,2 +1,3 @@
 fn main() {
+    println!("hello");
 }
"#;

        let (new_src, _new_tgt, cursor) = imitate_human_edits(source, target, 123);

        // Should have created a source patch with partial insertion
        assert!(!new_src.is_empty(), "Source should not be empty");
        assert!(cursor.is_some(), "Cursor should be set");

        let src_patch = Patch::parse_unified_diff(&new_src);
        assert!(
            src_patch.stats().added > 0,
            "Source should have added lines"
        );
    }

    #[test]
    fn test_imitate_human_edits_pure_insertion_intermediate_content() {
        // Verify the actual intermediate content is a realistic partial typing state
        let source = "";
        let target = r#"--- a/test.rs
+++ b/test.rs
@@ -1,2 +1,3 @@
 fn main() {
+    println!("hello world");
 }
"#;

        // Test with multiple seeds to see different split points
        // (the split position is seed-dependent, so some seeds may yield the
        // full line rather than a strict prefix).
        let mut found_partial = false;
        for seed in 1..=50 {
            let (new_src, new_tgt, cursor) = imitate_human_edits(source, target, seed);

            if cursor.is_some() {
                let src_patch = Patch::parse_unified_diff(&new_src);
                let tgt_patch = Patch::parse_unified_diff(&new_tgt);

                // Find the added line in source
                for hunk in &src_patch.hunks {
                    for line in &hunk.lines {
                        if let PatchLine::Addition(content) = line {
                            // The partial line should be a prefix of the full line
                            let full_line = "    println!(\"hello world\");";
                            if content != full_line && full_line.starts_with(content) {
                                found_partial = true;

                                // Verify target has the partial as deletion
                                let mut has_deletion = false;
                                for tgt_hunk in &tgt_patch.hunks {
                                    for tgt_line in &tgt_hunk.lines {
                                        if let PatchLine::Deletion(del_content) = tgt_line {
                                            if del_content == content {
                                                has_deletion = true;
                                            }
                                        }
                                    }
                                }
                                assert!(
                                    has_deletion,
                                    "Target should have deletion of partial line"
                                );
                            }
                        }
                    }
                }
            }
        }

        assert!(
            found_partial,
            "At least one seed should produce a partial intermediate state"
        );
    }

    #[test]
    fn test_imitate_human_edits_inserts_after_last_source_edit() {
        // Regression test: intermediate content should appear after the last edit
        // in the source patch, not at the position of the first target edit.
        // This ensures the diff output correctly imitates human typing order.
        //
        // The bug was: when source has edits and target has a pure insertion,
        // the intermediate content was inserted at tgt_edit_loc.line_index_within_hunk
        // (position of first target edit) instead of after the last source edit.
        //
        // Source patch has edits at lines 1-4, target has a new edit at line 10
        // (different location to avoid the "same line" early return)
        let source = r#"--- a/test.py
+++ b/test.py
@@ -1,4 +1,5 @@
+import foo
 import bar
-import old
 import baz
+import qux
"#;
        // Target has a pure insertion at a different line (line 10, not overlapping with source)
        let target = r#"--- a/test.py
+++ b/test.py
@@ -10,3 +10,4 @@
 def main():
+    print("hello world")
 pass
"#;

        // Use a seed that produces a partial result
        let (new_src, _new_tgt, cursor) = imitate_human_edits(source, target, 42);

        // The function should produce a modified patch
        assert!(cursor.is_some(), "Should produce intermediate state");

        let src_patch = Patch::parse_unified_diff(&new_src);
        let all_additions: Vec<_> = src_patch
            .hunks
            .iter()
            .flat_map(|h| h.lines.iter())
            .filter_map(|l| match l {
                PatchLine::Addition(s) => Some(s.as_str()),
                _ => None,
            })
            .collect();

        // The intermediate content (partial 'print("hello world")') should be
        // the LAST addition, appearing after "+import qux" (the last source edit)
        let last_addition = all_additions.last().expect("Should have additions");
        assert!(
            last_addition.trim_start().starts_with("pr"),
            "Intermediate content should be the last addition (partial 'print'), but last was: {:?}",
            last_addition
        );

        // Verify the original source edits are still in order before the intermediate
        let foo_pos = all_additions.iter().position(|s| *s == "import foo");
        let qux_pos = all_additions.iter().position(|s| *s == "import qux");
        let intermediate_pos = all_additions
            .iter()
            .position(|s| s.trim_start().starts_with("pr"));

        assert!(foo_pos.is_some(), "Should have 'import foo'");
        assert!(qux_pos.is_some(), "Should have 'import qux'");
        assert!(
            intermediate_pos.is_some(),
            "Should have intermediate content"
        );

        // NOTE: these compare Option<usize> values directly; this is sound
        // because Option's PartialOrd orders None < Some(_) and the asserts
        // above guarantee all three are Some.
        assert!(
            foo_pos < qux_pos && qux_pos < intermediate_pos,
            "Order should be: foo < qux < intermediate. Got foo={:?}, qux={:?}, intermediate={:?}",
            foo_pos,
            qux_pos,
            intermediate_pos
        );
    }

    #[test]
    fn test_cursor_excerpt_with_multibyte_utf8() {
        // Test that cursor excerpt handles multi-byte UTF-8 characters correctly
        // The Chinese character '第' is 3 bytes (0..3)
        let cursor = CursorPosition {
            file: "test.md".to_string(),
            line: 1,
            column: 1, // Byte index 1 is inside '第' (bytes 0..3)
        };

        let source_patch = r#"--- a/test.md
+++ b/test.md
@@ -1,1 +1,1 @@
+第 14 章 Flask 工作原理与机制解析**
"#;

        let target_patch = "";

        // This should not panic even though column=1 is not a char boundary
        // (str::split_at panics on non-boundary indices; the excerpt code is
        // expected to snap the column to a valid boundary first).
        let result = get_cursor_excerpt(&cursor, source_patch, target_patch);

        // The function should handle the invalid byte index gracefully
        if let Some(excerpt) = result {
            assert!(
                excerpt.contains("<|user_cursor|>"),
                "Cursor excerpt should contain marker"
            );
            // The marker should be placed at a valid character boundary
            // (either at the start or after '第')
        }
    }
}