1//! `ep split-commit` implementation.
2//!
3//! This command generates a single evaluation example JSON object from a
4//! chronologically-ordered unified diff (a "commit").
5//!
6//! TODO: Port Python code to generate chronologically-ordered commits
7use crate::reorder_patch::{Patch, PatchLine, extract_edits, locate_edited_line};
8
/// Find the largest valid UTF-8 char boundary at or before `index` in `s`.
///
/// Mirrors the (unstable) `str::floor_char_boundary`: indices past the end
/// clamp to `s.len()`. Since index 0 is always a valid boundary, the
/// downward walk below is guaranteed to terminate.
fn floor_char_boundary(s: &str, index: usize) -> usize {
    let mut boundary = index.min(s.len());
    while !s.is_char_boundary(boundary) {
        boundary -= 1;
    }
    boundary
}
23use anyhow::{Context as _, Result};
24use clap::Args;
25use edit_prediction::example_spec::ExampleSpec;
26use rand::Rng;
27use rand::SeedableRng;
28use serde::{Deserialize, Serialize};
29use similar::{DiffTag, TextDiff};
30use std::collections::BTreeSet;
31use std::fs;
32use std::io::{self, Write};
33use std::path::Path;
34use std::path::PathBuf;
35
/// `ep split-commit` CLI args.
// NOTE: the `///` field docs below double as clap help text; keep them short
// and user-facing.
#[derive(Debug, Args, Clone)]
pub struct SplitCommitArgs {
    /// Split point (float 0.0-1.0 for fraction, or integer for index)
    // Parsed by `parse_split_point`; values that don't parse silently fall
    // back to a random split point.
    #[arg(long, short = 's')]
    pub split_point: Option<String>,

    /// Random seed for reproducibility
    // When absent, a fresh seed is drawn per input commit in multi-sample mode.
    #[arg(long)]
    pub seed: Option<u64>,

    /// Pretty-print JSON output
    #[arg(long, short = 'p')]
    pub pretty: bool,

    /// Number of samples to generate per commit (samples random split points)
    // Duplicate samples (identical serialized JSON) are de-duplicated, so the
    // actual output count per commit may be lower than this.
    #[arg(long, short = 'n')]
    pub num_samples: Option<usize>,
}
55
/// Input format for annotated commits (JSON Lines).
///
/// One `AnnotatedCommit` is deserialized from each non-empty input line in
/// `run_split_commit`.
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)]
pub struct AnnotatedCommit {
    /// Repository path (e.g., "repos/zed")
    pub repo: String,
    /// Repository URL (e.g., "https://github.com/zed-industries/zed")
    pub repo_url: String,
    /// Commit SHA
    pub commit_sha: String,
    /// Chronologically reordered commit diff (this is the diff that actually
    /// gets split into source/target patches)
    pub reordered_commit: String,
    /// Original commit diff
    pub original_commit: String,
    /// Whether diff stats match between original and reordered
    pub diff_stats_match: bool,
}
73
/// Cursor position in a file.
///
/// Rendered as `file:line:column` by its `Display` impl.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CursorPosition {
    /// Path of the file the cursor is in, as recorded in the patch.
    pub file: String,
    /// Line number within the file (unified-diff line numbers, 1-based).
    pub line: usize,
    /// Byte offset within the line.
    // NOTE(review): producers disagree on the base — `imitate_human_edits`
    // writes `len + 1` (1-based) while `locate_end_of_last_edit` writes
    // `content.len()` — confirm which convention consumers expect.
    pub column: usize,
}
81
82impl std::fmt::Display for CursorPosition {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 write!(f, "{}:{}:{}", self.file, self.line, self.column)
85 }
86}
87
/// Represents a split commit with source and target patches.
///
/// The source half becomes the example's visible edit history; the target
/// half becomes the edits the model is expected to predict.
#[derive(Debug, Clone)]
pub struct SplitCommit {
    /// Unified diff of the edits before the split point.
    pub source_patch: String,
    /// Unified diff of the remaining edits at/after the split point.
    pub target_patch: String,
}
94
/// Split point specification for evaluation generation.
#[derive(Debug, Clone)]
pub enum SplitPoint {
    /// Fraction of total edits (0.0 to 1.0); resolved against the commit's
    /// added+removed line count and clamped to that count.
    Fraction(f64),
    /// Absolute index into the commit's edited lines, clamped to the total
    /// number of edits.
    Index(usize),
}
103
104fn parse_split_point(value: &str) -> Option<SplitPoint> {
105 if value.contains('.') {
106 value.parse::<f64>().ok().map(SplitPoint::Fraction)
107 } else {
108 value.parse::<usize>().ok().map(SplitPoint::Index)
109 }
110}
111
/// Entry point for the `ep split-commit` subcommand.
///
/// This runs synchronously and outputs JSON Lines (one output per input line).
///
/// Inputs are JSON Lines of [`AnnotatedCommit`]; a path of `-` (or an empty
/// input list) means stdin. For each commit, either one example is generated
/// (honoring `--split-point`/`--seed`), or — when `--num-samples` is set —
/// up to that many de-duplicated examples with random split points.
pub fn run_split_commit(
    args: &SplitCommitArgs,
    inputs: &[PathBuf],
    output_path: Option<&PathBuf>,
) -> Result<()> {
    use std::collections::HashSet;
    use std::io::BufRead;

    // No inputs means "read from stdin", expressed as the conventional "-".
    let stdin_path = PathBuf::from("-");
    let inputs = if inputs.is_empty() {
        std::slice::from_ref(&stdin_path)
    } else {
        inputs
    };

    // Unparseable --split-point values become None, which means "random".
    let split_point = args.split_point.as_deref().and_then(parse_split_point);
    let mut output_lines = Vec::new();

    for input_path in inputs {
        let input: Box<dyn BufRead> = if input_path.as_os_str() == "-" {
            Box::new(io::BufReader::new(io::stdin()))
        } else {
            let file = fs::File::open(input_path)
                .with_context(|| format!("failed to open input file {}", input_path.display()))?;
            Box::new(io::BufReader::new(file))
        };

        for (line_num, line_result) in input.lines().enumerate() {
            let line =
                line_result.with_context(|| format!("failed to read line {}", line_num + 1))?;

            // Skip blank lines so trailing newlines don't fail JSON parsing.
            if line.trim().is_empty() {
                continue;
            }

            let annotated: AnnotatedCommit = serde_json::from_str(&line)
                .with_context(|| format!("failed to parse JSON at line {}", line_num + 1))?;

            // Generate multiple samples if num_samples is set
            if let Some(num_samples) = args.num_samples {
                let mut seen_samples: HashSet<String> = HashSet::new();
                // Without --seed, a fresh base seed is drawn per input line,
                // so reruns are only reproducible when --seed is given.
                let base_seed = args.seed.unwrap_or_else(|| rand::random());

                for sample_idx in 0..num_samples {
                    // Derive a distinct (but reproducible) seed per sample.
                    let sample_seed = base_seed.wrapping_add(sample_idx as u64);

                    let case = generate_evaluation_example_from_ordered_commit(
                        &annotated.reordered_commit,
                        &annotated.repo_url,
                        &annotated.commit_sha,
                        None, // Use random split point for multi-sample mode
                        Some(sample_seed),
                        Some(sample_idx),
                    )
                    .with_context(|| {
                        format!(
                            "failed to generate evaluation example for commit {} at line {} (sample {})",
                            annotated.commit_sha,
                            line_num + 1,
                            sample_idx
                        )
                    })?;

                    let json = if args.pretty {
                        serde_json::to_string_pretty(&case)
                    } else {
                        serde_json::to_string(&case)
                    }
                    .context("failed to serialize evaluation case as JSON")?;

                    // Only add unique samples (different split points may produce same result)
                    if seen_samples.insert(json.clone()) {
                        output_lines.push(json);
                    }
                }
            } else {
                let case = generate_evaluation_example_from_ordered_commit(
                    &annotated.reordered_commit,
                    &annotated.repo_url,
                    &annotated.commit_sha,
                    split_point.clone(),
                    args.seed,
                    None,
                )
                .with_context(|| {
                    format!(
                        "failed to generate evaluation example for commit {} at line {}",
                        annotated.commit_sha,
                        line_num + 1
                    )
                })?;

                let json = if args.pretty {
                    serde_json::to_string_pretty(&case)
                } else {
                    serde_json::to_string(&case)
                }
                .context("failed to serialize evaluation case as JSON")?;

                output_lines.push(json);
            }
        }
    }

    // Join with a single trailing newline, unless there is no output at all.
    let output_content = output_lines.join("\n") + if output_lines.is_empty() { "" } else { "\n" };

    if let Some(path) = output_path {
        fs::write(path, &output_content)
            .with_context(|| format!("failed to write output to {}", path.display()))?;
    } else {
        io::stdout()
            .write_all(output_content.as_bytes())
            .context("failed to write to stdout")?;
    }

    Ok(())
}
232
/// Main function to generate an evaluation example from an ordered commit.
///
/// # Arguments
/// * `commit` - Chronologically ordered unified diff of the commit
/// * `repository_url` - URL of the repository
/// * `commit_hash` - Hash of the commit
/// * `split_point` - Point at which the commit will be split (None for random)
/// * `seed` - Optional seed for randomness
/// * `sample_num` - Optional sample number for generating unique names
///
/// # Errors
/// Fails when the commit contains no edits, or when no cursor position or
/// cursor excerpt can be derived from the split patches.
pub fn generate_evaluation_example_from_ordered_commit(
    commit: &str,
    repository_url: &str,
    commit_hash: &str,
    split_point: Option<SplitPoint>,
    seed: Option<u64>,
    sample_num: Option<usize>,
) -> Result<ExampleSpec> {
    // Seeded RNG for reproducibility; thread RNG otherwise.
    let mut rng: Box<dyn rand::RngCore> = match seed {
        Some(seed) => Box::new(rand::rngs::StdRng::seed_from_u64(seed)),
        None => Box::new(rand::rngs::ThreadRng::default()),
    };

    // Parse and normalize the commit
    let mut patch = Patch::parse_unified_diff(commit);

    // Filter header to only keep lines starting with "//" (the group-header
    // convention used by chronologically reordered commits).
    let header_lines: Vec<&str> = patch
        .header
        .lines()
        .filter(|line| line.starts_with("//"))
        .collect();
    patch.header = if header_lines.is_empty() {
        String::new()
    } else {
        header_lines.join("\n") + "\n"
    };
    let commit_normalized = patch.to_string();

    // Compute the split point from the total count of edited lines.
    let stats = patch.stats();
    let num_edits = stats.added + stats.removed;

    anyhow::ensure!(num_edits != 0, "no edits found in commit");

    let split = match split_point {
        // Random splits start at 1, so `split == 0` can only come from an
        // explicit Fraction/Index argument.
        None => rng.random_range(1..=num_edits),
        Some(SplitPoint::Fraction(f)) => {
            let v = (f * num_edits as f64).floor() as usize;
            v.min(num_edits)
        }
        Some(SplitPoint::Index(i)) => i.min(num_edits),
    };

    // Split the commit into source and target patches
    let (prefix, suffix) = split_ordered_commit(&commit_normalized, split);

    let mut split_commit = SplitCommit {
        source_patch: prefix,
        target_patch: suffix,
    };

    // Imitate human edits: may move a partially-typed line from the target
    // into the source and pin the cursor at the typing point within it.
    let human_edit_seed = rng.random_range(1..=10000u64);
    let (src_patch, tgt_patch, cursor_opt) = imitate_human_edits(
        &split_commit.source_patch,
        &split_commit.target_patch,
        human_edit_seed,
    );
    split_commit.source_patch = src_patch;
    split_commit.target_patch = tgt_patch;

    // Sample cursor position (only needed if imitation didn't provide one).
    let cursor = match cursor_opt {
        Some(c) => c,
        None => sample_cursor_position(&patch, &split_commit)
            .context("failed to sample cursor position")?,
    };

    // Get cursor excerpt
    let cursor_excerpt = get_cursor_excerpt(
        &cursor,
        &split_commit.source_patch,
        &split_commit.target_patch,
    )
    .context("failed to generate cursor excerpt")?;

    // Handle edge case where split_point == 0
    // NOTE(review): this runs *after* the cursor and excerpt were computed
    // from the not-yet-cleared target patch, and it empties the expected
    // patch entirely — confirm this ordering and the resulting empty
    // expectation are intentional.
    if split == 0 {
        split_commit.target_patch = String::new();
    }

    // Derive the example name from the repo name and a short commit SHA.
    let repo_name = repository_url
        .trim_end_matches('/')
        .rsplit('/')
        .next()
        .unwrap_or("unknown");
    let short_sha = &commit_hash[..commit_hash.len().min(8)];
    let name = match sample_num {
        Some(n) => format!("{}-{}-{}", repo_name, short_sha, n),
        None => format!("{}-{}", repo_name, short_sha),
    };

    Ok(ExampleSpec {
        name,
        repository_url: repository_url.to_string(),
        // The example replays this commit's edits on top of its parent.
        revision: format!("{}~1", commit_hash),
        edit_history: split_commit.source_patch.clone(),
        // cursor_position: cursor.to_string(),
        cursor_path: Path::new(&cursor.file).into(),
        cursor_position: cursor_excerpt,
        expected_patches: vec![split_commit.target_patch],
        tags: vec![],
        reasoning: None,
        uncommitted_diff: String::new(),
        rejected_patch: None,
    })
}
350
351/// Split an ordered commit into source and target commits.
352///
353/// # Arguments
354/// * `commit` - Ordered commit string
355/// * `split_pos` - Position to split the commit (number of edited lines)
356///
357/// # Returns
358/// A tuple of (source_diff, target_diff)
359pub fn split_ordered_commit(commit: &str, split_pos: usize) -> (String, String) {
360 let patch = Patch::parse_unified_diff(commit);
361 let source_edits: BTreeSet<usize> = (0..split_pos).collect();
362 let (source, target) = extract_edits(&patch, &source_edits);
363
364 let mut source_str = source.to_string();
365 let target_str = target.to_string();
366
367 // Strip last group header from the source (lines starting with "//" at the end)
368 let source_lines: Vec<&str> = source_str.lines().collect();
369 let mut end_idx = source_lines.len();
370 for i in (0..source_lines.len()).rev() {
371 if source_lines[i].starts_with("//") {
372 end_idx = i;
373 } else {
374 break;
375 }
376 }
377 if end_idx < source_lines.len() {
378 source_str = source_lines[..end_idx].join("\n");
379 if !source_str.is_empty() {
380 source_str.push('\n');
381 }
382 }
383
384 (source_str, target_str)
385}
386
/// Tokenize text into words and non-word characters.
///
/// Alphanumeric runs form word tokens. An underscore is appended to the
/// current word and immediately terminates it (so `foo_bar` tokenizes as
/// `["foo_", "bar"]`). Every other character (punctuation, whitespace)
/// flushes the pending word and becomes a single-character token of its own.
fn tokenize(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut word = String::new();

    for ch in text.chars() {
        match ch {
            c if c.is_alphanumeric() => word.push(c),
            '_' => {
                // Underscore joins the current word, then closes it.
                word.push('_');
                tokens.push(std::mem::take(&mut word));
            }
            other => {
                // Flush any pending word, then emit the character itself.
                if !word.is_empty() {
                    tokens.push(std::mem::take(&mut word));
                }
                tokens.push(other.to_string());
            }
        }
    }

    if !word.is_empty() {
        tokens.push(word);
    }

    tokens
}
417
/// Calculate the weight for a split position based on the character at that position.
///
/// `pos` is a *byte* offset into `text` (callers pass `String::len()` of the
/// text typed so far). Higher weights indicate more natural pause points
/// (e.g., after punctuation, at identifier boundaries). Lower weights
/// indicate less natural points (e.g., mid-identifier).
///
/// Returns the minimum weight of 1 for out-of-range positions or positions
/// that do not fall on a UTF-8 character boundary.
fn position_weight(text: &str, pos: usize) -> u32 {
    // Fix: the previous implementation indexed a `Vec<char>` with the byte
    // offset, silently mis-weighting (or rejecting) any text containing
    // multi-byte characters. Slice on char boundaries instead.
    if pos == 0 || pos > text.len() || !text.is_char_boundary(pos) {
        return 1;
    }

    // The character just before this position (what we just "typed").
    let prev_char = match text[..pos].chars().next_back() {
        Some(c) => c,
        None => return 1,
    };

    // High weight: natural pause points (end of statement/argument, opening brackets)
    if matches!(prev_char, ',' | ';' | ':' | '(' | '[' | '{') {
        return 10;
    }

    // High weight: closing brackets (finished a group)
    if matches!(prev_char, ')' | ']' | '}') {
        return 8;
    }

    // Medium weight: operators and method chains
    if matches!(
        prev_char,
        '.' | '+' | '-' | '*' | '/' | '=' | '<' | '>' | '&' | '|' | '!'
    ) {
        return 5;
    }

    // Check if we're at the end of an identifier (word char followed by non-word char)
    let is_prev_word_char = prev_char.is_alphanumeric() || prev_char == '_';
    let is_next_word_char = text[pos..]
        .chars()
        .next()
        .map_or(false, |c| c.is_alphanumeric() || c == '_');

    if is_prev_word_char && !is_next_word_char {
        // End of identifier - high weight
        return 8;
    }

    // Whitespace is a natural pause
    if prev_char.is_whitespace() {
        return 6;
    }

    // Mid-identifier: low weight (rare autocomplete scenarios)
    if is_prev_word_char && is_next_word_char {
        return 1;
    }

    // Default medium-low weight
    3
}
477
/// Select a weighted random index from a list of weights.
///
/// Returns an index based on the weights, using the provided seed for
/// deterministic selection: the seed is reduced modulo the total weight and
/// the index whose cumulative weight range contains it wins. All-zero
/// weights fall back to uniform selection; an empty slice yields 0.
fn weighted_select(weights: &[u32], seed: u64) -> usize {
    if weights.is_empty() {
        return 0;
    }

    let total: u64 = weights.iter().map(|&w| u64::from(w)).sum();
    if total == 0 {
        // Fallback to uniform selection if all weights are zero
        return seed as usize % weights.len();
    }

    // Walk the weights, subtracting each from the reduced seed; the first
    // weight larger than what remains is the selected bucket.
    let mut remaining = seed % total;
    for (idx, &weight) in weights.iter().enumerate() {
        let weight = u64::from(weight);
        if remaining < weight {
            return idx;
        }
        remaining -= weight;
    }

    // Unreachable in practice (remaining < total); kept as a safe fallback.
    weights.len() - 1
}
507
508/// Calculate similarity ratio between two strings (0-100).
509fn fuzzy_ratio(s1: &str, s2: &str) -> u32 {
510 if s1.is_empty() && s2.is_empty() {
511 return 100;
512 }
513 if s1.is_empty() || s2.is_empty() {
514 return 0;
515 }
516
517 let diff = TextDiff::from_chars(s1, s2);
518 let matching: usize = diff
519 .ops()
520 .iter()
521 .filter_map(|op| {
522 if matches!(op.tag(), DiffTag::Equal) {
523 Some(op.new_range().len())
524 } else {
525 None
526 }
527 })
528 .sum();
529
530 let total = s1.len() + s2.len();
531 ((2 * matching * 100) / total) as u32
532}
533
534/// Imitate human edits by introducing partial line edits.
535///
536/// This function simulates how a human might incrementally type code,
537/// rather than making complete line replacements.
538pub fn imitate_human_edits(
539 source_patch: &str,
540 target_patch: &str,
541 seed: u64,
542) -> (String, String, Option<CursorPosition>) {
543 let no_change = (source_patch.to_string(), target_patch.to_string(), None);
544
545 let src_patch = Patch::parse_unified_diff(source_patch);
546 let tgt_patch = Patch::parse_unified_diff(target_patch);
547
548 if tgt_patch.hunks.is_empty() {
549 return no_change;
550 }
551
552 // Try to locate the first edit in target
553 let tgt_edit_loc = match locate_edited_line(&tgt_patch, 0) {
554 Some(loc) => loc,
555 None => return no_change,
556 };
557
558 let tgt_is_addition = matches!(tgt_edit_loc.patch_line, PatchLine::Addition(_));
559 if !tgt_is_addition {
560 return no_change;
561 }
562
563 let tgt_line = match &tgt_edit_loc.patch_line {
564 PatchLine::Addition(s) => s.clone(),
565 _ => return no_change,
566 };
567
568 // Try to locate the last edit in source
569 let src_edit_loc = locate_edited_line(&src_patch, -1);
570
571 // Check if source has ANY edit at the same line as target's first edit
572 // We need to iterate through all edits to check this
573 let src_has_edit_at_target_line = {
574 let mut found = false;
575 let mut idx = 0isize;
576 while let Some(loc) = locate_edited_line(&src_patch, idx) {
577 if loc.filename == tgt_edit_loc.filename
578 && loc.target_line_number == tgt_edit_loc.target_line_number
579 {
580 found = true;
581 break;
582 }
583 idx += 1;
584 }
585 found
586 };
587
588 // Check if this is a replacement (deletion followed by insertion on the same line)
589 // or a pure insertion (no corresponding deletion in source)
590 let is_replacement = src_edit_loc.as_ref().map_or(false, |loc| {
591 matches!(loc.patch_line, PatchLine::Deletion(_))
592 && loc.filename == tgt_edit_loc.filename
593 && loc.target_line_number == tgt_edit_loc.target_line_number
594 });
595
596 // If source has an edit at the same line but it's not a replacement (i.e., it's an addition),
597 // we shouldn't process this as a pure insertion either
598 if !is_replacement && src_has_edit_at_target_line {
599 return no_change;
600 }
601
602 let src_line = if is_replacement {
603 match &src_edit_loc.as_ref().unwrap().patch_line {
604 PatchLine::Deletion(s) => s.clone(),
605 _ => return no_change,
606 }
607 } else {
608 // Pure insertion: source line is empty
609 String::new()
610 };
611
612 // Don't process if source and target are the same
613 if src_line == tgt_line {
614 return no_change;
615 }
616
617 // Tokenize both lines
618 let src_tokens = tokenize(&src_line);
619 let tgt_tokens = tokenize(&tgt_line);
620
621 // Convert to slices for similar
622 let src_refs: Vec<&str> = src_tokens.iter().map(|s| s.as_str()).collect();
623 let tgt_refs: Vec<&str> = tgt_tokens.iter().map(|s| s.as_str()).collect();
624
625 // Use similar to get diff operations
626 let diff = TextDiff::from_slices(&src_refs, &tgt_refs);
627
628 // Build weights for each possible split position
629 let mut position_weights: Vec<u32> = Vec::new();
630
631 // Simulate the edit process to collect weights for all possible split positions
632 {
633 let mut current_text = String::new();
634
635 for op in diff.ops() {
636 match op.tag() {
637 DiffTag::Equal => {
638 for i in op.old_range() {
639 current_text.push_str(&src_tokens[i]);
640 }
641 }
642 DiffTag::Replace => {
643 let ins: String = op.new_range().map(|i| tgt_tokens[i].as_str()).collect();
644 let del: String = op.old_range().map(|i| src_tokens[i].as_str()).collect();
645
646 // For insertion part
647 for ch in ins.chars() {
648 current_text.push(ch);
649 let weight = position_weight(¤t_text, current_text.len());
650 position_weights.push(weight);
651 }
652
653 // For deletion part (we're "untyping" from source)
654 for _ in del.chars() {
655 // Weight deletions lower as they represent removing text
656 position_weights.push(2);
657 }
658 }
659 DiffTag::Insert => {
660 let ins: String = op.new_range().map(|i| tgt_tokens[i].as_str()).collect();
661 for ch in ins.chars() {
662 current_text.push(ch);
663 let weight = position_weight(¤t_text, current_text.len());
664 position_weights.push(weight);
665 }
666 }
667 DiffTag::Delete => {
668 let del: String = op.old_range().map(|i| src_tokens[i].as_str()).collect();
669 for _ in del.chars() {
670 // Weight deletions lower
671 position_weights.push(2);
672 }
673 }
674 }
675 }
676 }
677
678 // Use weighted selection to choose split index
679 if position_weights.is_empty() {
680 return no_change;
681 }
682 let split_index = weighted_select(&position_weights, seed);
683
684 let mut edit_index = 0usize;
685 let mut new_src = String::new();
686 let mut split_found = false;
687 let mut last_old_end = 0usize;
688
689 for op in diff.ops() {
690 match op.tag() {
691 DiffTag::Equal => {
692 for i in op.old_range() {
693 new_src.push_str(&src_tokens[i]);
694 }
695 last_old_end = op.old_range().end;
696 }
697 DiffTag::Replace => {
698 // Handle replace as delete + insert
699 let del: String = op.old_range().map(|i| src_tokens[i].as_str()).collect();
700 let ins: String = op.new_range().map(|i| tgt_tokens[i].as_str()).collect();
701 let repl_len = del.len() + ins.len();
702 if edit_index + repl_len >= split_index {
703 // Split within this replace operation
704 let offset = split_index - edit_index;
705 if offset < ins.len() {
706 let safe_offset = floor_char_boundary(&ins, offset);
707 new_src.push_str(&ins[..safe_offset]);
708 } else {
709 new_src.push_str(&ins);
710 let del_offset = offset - ins.len();
711 let safe_del_offset = floor_char_boundary(&del, del_offset.min(del.len()));
712 new_src.push_str(&del[..safe_del_offset]);
713 }
714 split_found = true;
715 last_old_end = op.old_range().end;
716 break;
717 } else {
718 edit_index += repl_len;
719 new_src.push_str(&ins);
720 last_old_end = op.old_range().end;
721 }
722 }
723 DiffTag::Insert => {
724 let repl: String = op.new_range().map(|i| tgt_tokens[i].as_str()).collect();
725 if edit_index + repl.len() >= split_index {
726 let offset = split_index - edit_index;
727 let safe_offset = floor_char_boundary(&repl, offset);
728 new_src.push_str(&repl[..safe_offset]);
729 split_found = true;
730 break;
731 } else {
732 edit_index += repl.len();
733 new_src.push_str(&repl);
734 }
735 }
736 DiffTag::Delete => {
737 let repl: String = op.old_range().map(|i| src_tokens[i].as_str()).collect();
738 if edit_index + repl.len() >= split_index {
739 let offset = split_index - edit_index;
740 let safe_offset = floor_char_boundary(&repl, offset);
741 new_src.push_str(&repl[..safe_offset]);
742 split_found = true;
743 last_old_end = op.old_range().start + safe_offset.min(op.old_range().len());
744 break;
745 } else {
746 edit_index += repl.len();
747 new_src.push_str(&repl);
748 last_old_end = op.old_range().end;
749 }
750 }
751 }
752 }
753
754 if !split_found {
755 return no_change;
756 }
757
758 // Calculate cursor position
759 let cursor = CursorPosition {
760 file: tgt_edit_loc.filename.clone(),
761 line: if is_replacement {
762 src_edit_loc.as_ref().unwrap().source_line_number
763 } else {
764 tgt_edit_loc.target_line_number
765 },
766 column: new_src.len() + 1,
767 };
768
769 // Add remainder of source if similar enough to target remainder
770 let remainder_src: String = (last_old_end..src_tokens.len())
771 .map(|i| src_tokens[i].as_str())
772 .collect();
773 let remainder_tgt: String = (last_old_end..tgt_tokens.len())
774 .filter_map(|i| tgt_tokens.get(i).map(|s| s.as_str()))
775 .collect();
776
777 let ratio = fuzzy_ratio(&remainder_src, &remainder_tgt);
778 if ratio > 35 {
779 new_src.push_str(&remainder_src);
780 }
781
782 if new_src.trim().is_empty() {
783 return no_change;
784 }
785
786 if new_src == src_line {
787 return no_change;
788 }
789
790 // Build new source patch with the intermediate line
791 let mut new_src_patch = src_patch;
792 if is_replacement {
793 // For replacements, insert after the deletion line
794 let src_loc = src_edit_loc.as_ref().unwrap();
795 if let Some(hunk) = new_src_patch.hunks.get_mut(src_loc.hunk_index) {
796 hunk.lines.insert(
797 src_loc.line_index_within_hunk + 1,
798 PatchLine::Addition(new_src.clone()),
799 );
800 hunk.new_count += 1;
801 }
802 } else {
803 // For pure insertions, insert after the last edit in source patch
804 // This imitates human typing - the intermediate content is what the user is currently typing
805 let last_src_edit = locate_edited_line(&new_src_patch, -1);
806
807 if let Some(src_loc) = last_src_edit {
808 // Insert after the last edit in source
809 if let Some(hunk) = new_src_patch.hunks.get_mut(src_loc.hunk_index) {
810 hunk.lines.insert(
811 src_loc.line_index_within_hunk + 1,
812 PatchLine::Addition(new_src.clone()),
813 );
814 hunk.new_count += 1;
815 }
816 } else {
817 // Source patch is empty or has incompatible hunk structure, create a new hunk based on target
818 if let Some(tgt_hunk) = tgt_patch.hunks.get(tgt_edit_loc.hunk_index) {
819 let mut new_hunk = tgt_hunk.clone();
820 // Replace the full addition with the partial one
821 new_hunk.lines.clear();
822 for (i, line) in tgt_hunk.lines.iter().enumerate() {
823 if i == tgt_edit_loc.line_index_within_hunk {
824 new_hunk.lines.push(PatchLine::Addition(new_src.clone()));
825 } else {
826 match line {
827 PatchLine::Addition(_) => {
828 // Skip other additions from target
829 }
830 _ => new_hunk.lines.push(line.clone()),
831 }
832 }
833 }
834 new_hunk.new_count = new_hunk.old_count + 1;
835 new_src_patch.hunks.push(new_hunk);
836 // Copy header from target if source doesn't have one
837 if new_src_patch.header.is_empty() {
838 new_src_patch.header = tgt_patch.header.clone();
839 }
840 }
841 }
842 }
843
844 // Build new target patch with the intermediate line as deletion
845 let mut new_tgt_patch = tgt_patch;
846 if let Some(hunk) = new_tgt_patch.hunks.get_mut(tgt_edit_loc.hunk_index) {
847 hunk.lines.insert(
848 tgt_edit_loc.line_index_within_hunk,
849 PatchLine::Deletion(new_src),
850 );
851 hunk.old_count += 1;
852 }
853
854 (
855 new_src_patch.to_string(),
856 new_tgt_patch.to_string(),
857 Some(cursor),
858 )
859}
860
861/// Locate the end of the last edit in a patch.
862fn locate_end_of_last_edit(patch: &Patch) -> Option<CursorPosition> {
863 let loc = locate_edited_line(patch, -1)?;
864
865 let (line, col) = match &loc.patch_line {
866 PatchLine::Addition(content) => (loc.target_line_number, content.len()),
867 PatchLine::Deletion(_) => (loc.target_line_number, 1),
868 _ => return None,
869 };
870
871 Some(CursorPosition {
872 file: loc.filename,
873 line,
874 column: col,
875 })
876}
877
/// Locate the beginning of the first edit in a patch.
///
/// Returns a cursor placed just *before* the first edited line: near the end
/// of the preceding hunk line when one exists, otherwise column 0; the line
/// is the one above the edit, clamped to 1. `None` when the patch has no
/// edits or the preceding line has an unexpected kind.
fn locate_beginning_of_first_edit(patch: &Patch) -> Option<CursorPosition> {
    let loc = locate_edited_line(patch, 0)?;

    let hunk = patch.hunks.get(loc.hunk_index)?;
    let column = if loc.line_index_within_hunk > 0 {
        if let Some(prev_line) = hunk.lines.get(loc.line_index_within_hunk - 1) {
            let content = match prev_line {
                PatchLine::Context(s) | PatchLine::Addition(s) | PatchLine::Deletion(s) => s,
                _ => return None,
            };
            // NOTE(review): yields `len - 1` for non-empty lines and 0 for
            // empty ones — presumably "last column of the previous line";
            // confirm the off-by-one is intentional.
            content.len().max(1) - 1
        } else {
            0
        }
    } else {
        0
    };

    // One line above the first edit, never below line 1.
    let line = loc.target_line_number.saturating_sub(1).max(1);

    Some(CursorPosition {
        file: loc.filename,
        line,
        column,
    })
}
905
906/// Sample cursor position according to the following rules:
907/// 1. 50% chance of cursor being at the end of the source patch
908/// 2. 50% chance of cursor being at the beginning of the target patch
909pub fn sample_cursor_position(patch: &Patch, split_commit: &SplitCommit) -> Option<CursorPosition> {
910 // Try end of history first
911 let src_patch = Patch::parse_unified_diff(&split_commit.source_patch);
912 if let Some(cursor) = locate_end_of_last_edit(&src_patch) {
913 return Some(cursor);
914 }
915
916 // Try beginning of target
917 let tgt_patch = Patch::parse_unified_diff(&split_commit.target_patch);
918 if let Some(cursor) = locate_beginning_of_first_edit(&tgt_patch) {
919 return Some(cursor);
920 }
921
922 // Fallback: use the original patch
923 locate_end_of_last_edit(patch)
924}
925
/// Get cursor excerpt from the patches.
///
/// This extracts the lines around the cursor position with a cursor marker.
///
/// The excerpt is taken from the first hunk that matches the cursor's file —
/// preferring the source patch's last-edit hunk, then any target hunk, then
/// any source hunk — and a `<|user_cursor|>` marker is spliced into the line
/// matching `cursor.line`. Returns `None` when no hunk for the cursor's file
/// yields any lines.
pub fn get_cursor_excerpt(
    cursor: &CursorPosition,
    source_patch: &str,
    target_patch: &str,
) -> Option<String> {
    let mut excerpt_lines: Vec<String> = Vec::new();
    // File line number of the first collected excerpt line.
    let mut excerpt_first_line: usize = 0;

    // Search in the last hunk of source patch
    let src = Patch::parse_unified_diff(source_patch);
    if let Some(loc) = locate_edited_line(&src, -1) {
        if loc.filename == cursor.file && loc.target_line_number == cursor.line {
            if let Some(hunk) = src.hunks.get(loc.hunk_index) {
                excerpt_first_line = hunk.new_start as usize;
                // Collect the post-edit view of the hunk (additions + context).
                for line in &hunk.lines {
                    match line {
                        PatchLine::Addition(s) | PatchLine::Context(s) => {
                            excerpt_lines.push(s.clone());
                        }
                        _ => {}
                    }
                }
                // If hunk only has deletions (file deletion), include deletion lines
                if excerpt_lines.is_empty() {
                    excerpt_first_line = hunk.old_start as usize;
                    for line in &hunk.lines {
                        match line {
                            PatchLine::Deletion(s) => {
                                excerpt_lines.push(s.clone());
                            }
                            _ => {}
                        }
                    }
                }
            }
        }
    }

    // Search in target patch if not found
    if excerpt_lines.is_empty() {
        let tgt = Patch::parse_unified_diff(target_patch);
        // Search all hunks for the cursor file, not just the first edit's hunk
        for hunk in &tgt.hunks {
            if hunk.filename == cursor.file {
                excerpt_first_line = hunk.new_start as usize;
                // First try to collect deletions and context (what exists before edits)
                for line in &hunk.lines {
                    match line {
                        PatchLine::Deletion(s) | PatchLine::Context(s) => {
                            excerpt_lines.push(s.clone());
                        }
                        _ => {}
                    }
                }
                // If hunk only has additions (no deletions/context), include all lines
                // This handles cases like adding to an empty file or section
                if excerpt_lines.is_empty() {
                    for line in &hunk.lines {
                        match line {
                            PatchLine::Addition(s)
                            | PatchLine::Deletion(s)
                            | PatchLine::Context(s) => {
                                excerpt_lines.push(s.clone());
                            }
                            _ => {}
                        }
                    }
                }
                if !excerpt_lines.is_empty() {
                    break;
                }
            }
        }
    }

    // Also search source patch hunks if still not found (for fallback cursor case)
    if excerpt_lines.is_empty() {
        for hunk in &src.hunks {
            if hunk.filename == cursor.file {
                excerpt_first_line = hunk.new_start as usize;
                for line in &hunk.lines {
                    match line {
                        PatchLine::Addition(s) | PatchLine::Context(s) => {
                            excerpt_lines.push(s.clone());
                        }
                        _ => {}
                    }
                }
                // If hunk only has deletions, include deletion lines
                if excerpt_lines.is_empty() {
                    excerpt_first_line = hunk.old_start as usize;
                    for line in &hunk.lines {
                        match line {
                            PatchLine::Deletion(s) => {
                                excerpt_lines.push(s.clone());
                            }
                            _ => {}
                        }
                    }
                }
                if !excerpt_lines.is_empty() {
                    break;
                }
            }
        }
    }

    if excerpt_lines.is_empty() {
        return None;
    }

    // Add cursor marker
    // NOTE(review): this numbering assumes the collected lines are
    // consecutive file lines starting at `excerpt_first_line`; when a hunk
    // mixes skipped line kinds with the collected ones the count can drift —
    // confirm against the patch shapes this is fed.
    for (i, line) in excerpt_lines.iter_mut().enumerate() {
        let line_num = excerpt_first_line + i;
        if line_num == cursor.line {
            let col = cursor.column.min(line.len());
            // Ensure we split at a valid UTF-8 character boundary
            let col = if line.is_char_boundary(col) {
                col
            } else {
                // Find the nearest valid character boundary
                (0..=col)
                    .rev()
                    .find(|&i| line.is_char_boundary(i))
                    .unwrap_or(0)
            };
            let (before, after) = line.split_at(col);
            *line = format!("{}<|user_cursor|>{}", before, after);
            break;
        }
    }

    Some(excerpt_lines.join("\n"))
}
1063
#[cfg(test)]
mod tests {
    //! Unit tests for commit splitting, human-edit imitation, and cursor
    //! excerpt generation. Diff fixtures are inline unified diffs; the
    //! `////…////`-framed comment lines mimic the group headers produced by
    //! the chronological reordering step.
    use std::path::Path;

    use edit_prediction::example_spec::ExampleSpec;

    use super::*;

    #[test]
    fn test_tokenize() {
        let tokens = tokenize("hello world");
        assert_eq!(tokens, vec!["hello", " ", "world"]);

        let tokens = tokenize("foo_bar123 + baz");
        assert_eq!(tokens, vec!["foo_", "bar123", " ", "+", " ", "baz"]);

        let tokens = tokenize("print(\"hello\")");
        assert_eq!(tokens, vec!["print", "(", "\"", "hello", "\"", ")"]);

        let tokens = tokenize("hello_world");
        assert_eq!(tokens, vec!["hello_", "world"]);

        let tokens = tokenize("fn();");
        assert_eq!(tokens, vec!["fn", "(", ")", ";"]);
    }

    #[test]
    fn test_fuzzy_ratio() {
        assert_eq!(fuzzy_ratio("hello", "hello"), 100);
        assert_eq!(fuzzy_ratio("", ""), 100);
        assert!(fuzzy_ratio("hello", "world") < 50);
        assert!(fuzzy_ratio("hello world", "hello worl") > 80);
    }

    #[test]
    fn test_split_ordered_commit() {
        let commit = r#"// First change
--- a/test.rs
+++ b/test.rs
@@ -1,3 +1,4 @@
 fn main() {
+    println!("hello");
+    println!("world");
 }
"#;
        let patch = Patch::parse_unified_diff(commit);
        let stats = patch.stats();
        assert_eq!(stats.added, 2);

        let (source, target) = split_ordered_commit(commit, 1);

        // Source should have 1 addition
        let src_patch = Patch::parse_unified_diff(&source);
        assert_eq!(src_patch.stats().added, 1);

        // Target should have 1 addition
        let tgt_patch = Patch::parse_unified_diff(&target);
        assert_eq!(tgt_patch.stats().added, 1);
    }

    #[test]
    fn test_split_ordered_commit_with_deletions() {
        let commit = r#"// Change
--- a/test.rs
+++ b/test.rs
@@ -1,3 +1,3 @@
 fn main() {
-    println!("old");
+    println!("new");
 }
"#;
        let patch = Patch::parse_unified_diff(commit);
        let stats = patch.stats();
        assert_eq!(stats.added, 1);
        assert_eq!(stats.removed, 1);

        // Split at position 1 (after the deletion)
        let (source, target) = split_ordered_commit(commit, 1);

        let src_patch = Patch::parse_unified_diff(&source);
        let tgt_patch = Patch::parse_unified_diff(&target);

        // Source should have the deletion
        assert_eq!(src_patch.stats().removed, 1);
        // Target should have the addition
        assert_eq!(tgt_patch.stats().added, 1);
    }

    #[test]
    fn test_generate_evaluation_example() {
        let commit = r#"commit abc123
Author: Test <test@example.com>
Date: Mon Jan 1 00:00:00 2024

    Test commit

////////////////////////////////////////////////////////////////////////////////
// Add greeting
////////////////////////////////////////////////////////////////////////////////
--- a/test.rs
+++ b/test.rs
@@ -1,3 +1,5 @@
 fn main() {
+    println!("hello");
+    println!("world");
 }
"#;

        let result = generate_evaluation_example_from_ordered_commit(
            commit,
            "https://github.com/test/repo",
            "abc123",
            Some(SplitPoint::Fraction(0.5)),
            Some(42),
            None,
        );

        assert!(result.is_ok());
        let case = result.unwrap();
        assert_eq!(case.repository_url, "https://github.com/test/repo");
        // The example is checked out at the parent of the split commit.
        assert_eq!(case.revision, "abc123~1");
        assert!(!case.edit_history.is_empty());
    }

    #[test]
    fn test_generate_evaluation_example_reproducible() {
        let commit = r#"////////////////////////////////////////////////////////////////////////////////
// Add greeting
////////////////////////////////////////////////////////////////////////////////
--- a/test.rs
+++ b/test.rs
@@ -1,3 +1,5 @@
 fn main() {
+    println!("hello");
+    println!("world");
 }
"#;

        // Run twice with the same seed
        let result1 = generate_evaluation_example_from_ordered_commit(
            commit,
            "https://github.com/test/repo",
            "abc123",
            Some(SplitPoint::Fraction(0.5)),
            Some(12345),
            None,
        )
        .unwrap();

        let result2 = generate_evaluation_example_from_ordered_commit(
            commit,
            "https://github.com/test/repo",
            "abc123",
            Some(SplitPoint::Fraction(0.5)),
            Some(12345),
            None,
        )
        .unwrap();

        // Results should be identical
        assert_eq!(result1.edit_history, result2.edit_history);
        assert_eq!(result1.expected_patches, result2.expected_patches);
        assert_eq!(result1.cursor_position, result2.cursor_position);
    }

    #[test]
    fn test_cursor_position_display() {
        let cursor = CursorPosition {
            file: "src/main.rs".to_string(),
            line: 42,
            column: 10,
        };
        assert_eq!(cursor.to_string(), "src/main.rs:42:10");
    }

    #[test]
    fn test_imitate_human_edits_no_change_when_no_replacement() {
        // Source and target patches that don't form a replacement pattern
        let source = r#"--- a/test.rs
+++ b/test.rs
@@ -1,3 +1,4 @@
 fn main() {
+    println!("hello");
 }
"#;
        let target = r#"--- a/test.rs
+++ b/test.rs
@@ -1,3 +1,4 @@
 fn main() {
+    println!("world");
 }
"#;

        let (new_src, new_tgt, cursor) = imitate_human_edits(source, target, 42);

        // Should return unchanged when not a replacement pattern
        assert_eq!(new_src, source);
        assert_eq!(new_tgt, target);
        assert!(cursor.is_none());
    }

    #[test]
    fn test_split_point_fraction() {
        let commit = r#"// Change
--- a/test.rs
+++ b/test.rs
@@ -1,5 +1,10 @@
 fn main() {
+    line1();
+    line2();
+    line3();
+    line4();
+    line5();
 }
"#;

        // Split at 20% should give first edit in source
        let result = generate_evaluation_example_from_ordered_commit(
            commit,
            "",
            "hash",
            Some(SplitPoint::Fraction(0.2)),
            Some(1),
            None,
        );

        assert!(result.is_ok());
        let case = result.unwrap();

        // Source should have some edits
        let src_patch = Patch::parse_unified_diff(&case.edit_history);
        assert!(src_patch.stats().added > 0);
    }

    #[test]
    fn test_split_point_index() {
        let commit = r#"// Change
--- a/test.rs
+++ b/test.rs
@@ -1,5 +1,10 @@
 fn main() {
+    line1();
+    line2();
+    line3();
+    line4();
+    line5();
 }
"#;

        // Split at index 2 should give first 2 edits in source
        // With pure insertion handling, source gets 2 original + 1 partial = 3 additions
        let result = generate_evaluation_example_from_ordered_commit(
            commit,
            "",
            "hash",
            Some(SplitPoint::Index(2)),
            Some(1),
            None,
        );

        assert!(result.is_ok());
        let case = result.unwrap();

        let src_patch = Patch::parse_unified_diff(&case.edit_history);
        // Pure insertion adds a partial line, so we expect 3 (2 original + 1 partial)
        assert_eq!(src_patch.stats().added, 3);
    }

    #[test]
    fn test_cursor_excerpt_contains_marker() {
        let commit = r#"////////////////////////////////////////////////////////////////////////////////
// Add code
////////////////////////////////////////////////////////////////////////////////
--- a/test.rs
+++ b/test.rs
@@ -1,3 +1,5 @@
 fn main() {
+    println!("hello");
+    println!("world");
 }
"#;

        let result = generate_evaluation_example_from_ordered_commit(
            commit,
            "",
            "hash",
            Some(SplitPoint::Fraction(0.5)),
            Some(42),
            None,
        )
        .unwrap();

        // Cursor excerpt should contain the cursor marker
        assert!(
            result.cursor_position.contains("<|user_cursor|>"),
            "Cursor excerpt should contain marker: {}",
            result.cursor_position
        );
    }

    #[test]
    fn test_evaluation_case_json_serialization() {
        let case = ExampleSpec {
            name: "test-abc123".to_string(),
            repository_url: "https://github.com/test/repo".to_string(),
            revision: "abc123~1".to_string(),
            edit_history: "patch1".to_string(),
            cursor_path: Path::new("file.rs").into(),
            cursor_position: "some code<|user_cursor|>".to_string(),
            expected_patches: vec!["patch".to_string()],
            tags: vec![],
            reasoning: None,
            uncommitted_diff: String::new(),
            rejected_patch: None,
        };

        // Round-trip through JSON and make sure key fields survive.
        let json = serde_json::to_string(&case).unwrap();
        let deserialized: ExampleSpec = serde_json::from_str(&json).unwrap();

        assert_eq!(case.repository_url, deserialized.repository_url);
        assert_eq!(case.revision, deserialized.revision);
        assert_eq!(case.cursor_position, deserialized.cursor_position);
    }

    #[test]
    fn test_empty_commit_returns_error() {
        let commit = "";

        let result = generate_evaluation_example_from_ordered_commit(
            commit,
            "",
            "hash",
            Some(SplitPoint::Fraction(0.5)),
            Some(1),
            None,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_header_filtering() {
        let commit = r#"commit abc123
Author: Test
Date: Today

    Message

diff --git a/test.rs b/test.rs
index 123..456 789
////////////////////////////////////////////////////////////////////////////////
// First group
////////////////////////////////////////////////////////////////////////////////
--- a/test.rs
+++ b/test.rs
@@ -1,3 +1,4 @@
 fn main() {
+    code();
 }
"#;

        let result = generate_evaluation_example_from_ordered_commit(
            commit,
            "",
            "hash",
            Some(SplitPoint::Index(1)),
            Some(1),
            None,
        );

        assert!(result.is_ok());
        let case = result.unwrap();

        // The edit history should contain the group header (// lines)
        // but not the commit metadata
        assert!(!case.edit_history.contains("Author:"));
        assert!(!case.edit_history.contains("Date:"));
    }

    #[test]
    fn test_position_weight() {
        // High weight positions (natural pause points)
        assert_eq!(position_weight("foo(", 4), 10); // After '('
        assert_eq!(position_weight("a, b", 2), 10); // After ','
        assert_eq!(position_weight("x;", 2), 10); // After ';'
        assert_eq!(position_weight("a: b", 2), 10); // After ':'
        assert_eq!(position_weight("[", 1), 10); // After '['
        assert_eq!(position_weight("{", 1), 10); // After '{'

        // High weight for closing brackets
        assert_eq!(position_weight("foo)", 4), 8); // After ')'
        assert_eq!(position_weight("]", 1), 8); // After ']'
        assert_eq!(position_weight("}", 1), 8); // After '}'

        // High weight at end of identifier
        assert_eq!(position_weight("foo ", 3), 8); // End of 'foo' before space
        assert_eq!(position_weight("bar(", 3), 8); // End of 'bar' before '('

        // Medium weight for operators
        assert_eq!(position_weight("a + b", 3), 5); // After '+'
        assert_eq!(position_weight("x.", 2), 5); // After '.'
        assert_eq!(position_weight("a=b", 2), 5); // After '='

        // Medium weight for whitespace
        assert_eq!(position_weight("a  ", 2), 6); // After space

        // Low weight mid-identifier
        assert_eq!(position_weight("foobar", 3), 1); // Mid-identifier 'foo|bar'

        // Edge cases
        assert_eq!(position_weight("", 0), 1); // Empty string
        assert_eq!(position_weight("a", 0), 1); // Position 0
    }

    #[test]
    fn test_weighted_select() {
        // Test that weighted selection returns correct indices
        let weights = vec![1, 10, 1];

        // With total weight 12, seed 0 should select index 0
        // seed 0 % 12 = 0, cumulative: 1 at idx 0, so returns 0
        assert_eq!(weighted_select(&weights, 0), 0);

        // seed 1 % 12 = 1, cumulative: 1 at idx 0 (1 < 1 is false), 11 at idx 1 (1 < 11 is true)
        assert_eq!(weighted_select(&weights, 1), 1);

        // seed 10 % 12 = 10, cumulative: 1, 11 at idx 1 (10 < 11 is true)
        assert_eq!(weighted_select(&weights, 10), 1);

        // seed 11 % 12 = 11, cumulative: 1, 11 at idx 1 (11 < 11 is false), 12 at idx 2 (11 < 12 is true)
        assert_eq!(weighted_select(&weights, 11), 2);

        // Empty weights should return 0
        let empty: Vec<u32> = vec![];
        assert_eq!(weighted_select(&empty, 42), 0);

        // Single weight should always return index 0
        let single = vec![10];
        assert_eq!(weighted_select(&single, 0), 0);
        assert_eq!(weighted_select(&single, 100), 0);
    }

    #[test]
    fn test_weighted_split_prefers_natural_boundaries() {
        // Test that with different seeds, weighted selection tends to prefer
        // positions after punctuation over mid-identifier positions
        let text_with_punctuation = "foo(bar, baz)";
        let text_mid_identifier = "foobar";

        // Position after '(' should have high weight
        let weight_after_paren = position_weight(text_with_punctuation, 4);
        // Position after ',' should have high weight
        let weight_after_comma = position_weight(text_with_punctuation, 8);
        // Position mid-identifier should have low weight
        let weight_mid_ident = position_weight(text_mid_identifier, 3);

        assert!(
            weight_after_paren > weight_mid_ident,
            "After '(' ({}) should be weighted higher than mid-identifier ({})",
            weight_after_paren,
            weight_mid_ident
        );
        assert!(
            weight_after_comma > weight_mid_ident,
            "After ',' ({}) should be weighted higher than mid-identifier ({})",
            weight_after_comma,
            weight_mid_ident
        );
    }

    #[test]
    fn test_imitate_human_edits_pure_insertion() {
        // Source patch is empty (no edits yet)
        // Target patch has a pure insertion (adding a new line)
        let source = r#"--- a/test.rs
+++ b/test.rs
@@ -1,2 +1,2 @@
 fn main() {
 }
"#;
        let target = r#"--- a/test.rs
+++ b/test.rs
@@ -1,2 +1,3 @@
 fn main() {
+    println!("debug");
 }
"#;

        let (new_src, new_tgt, cursor) = imitate_human_edits(source, target, 42);

        // Should have transformed the patches
        assert_ne!(
            new_src, source,
            "Source should be modified for pure insertion"
        );
        assert_ne!(
            new_tgt, target,
            "Target should be modified for pure insertion"
        );
        assert!(cursor.is_some(), "Cursor should be set");

        // Source should now have a partial addition
        let src_patch = Patch::parse_unified_diff(&new_src);
        assert!(
            src_patch.stats().added > 0,
            "Source should have added lines"
        );

        // Target should have both a deletion (of partial) and addition (of full)
        let tgt_patch = Patch::parse_unified_diff(&new_tgt);
        assert!(
            tgt_patch.stats().removed > 0,
            "Target should have removed lines (partial)"
        );
        assert!(
            tgt_patch.stats().added > 0,
            "Target should have added lines (full)"
        );

        // The cursor should be in test.rs
        let cursor = cursor.unwrap();
        assert_eq!(cursor.file, "test.rs");
    }

    #[test]
    fn test_imitate_human_edits_pure_insertion_empty_source() {
        // Source patch has no hunks at all
        let source = "";
        let target = r#"--- a/test.rs
+++ b/test.rs
@@ -1,2 +1,3 @@
 fn main() {
+    println!("hello");
 }
"#;

        let (new_src, _new_tgt, cursor) = imitate_human_edits(source, target, 123);

        // Should have created a source patch with partial insertion
        assert!(!new_src.is_empty(), "Source should not be empty");
        assert!(cursor.is_some(), "Cursor should be set");

        let src_patch = Patch::parse_unified_diff(&new_src);
        assert!(
            src_patch.stats().added > 0,
            "Source should have added lines"
        );
    }

    #[test]
    fn test_imitate_human_edits_pure_insertion_intermediate_content() {
        // Verify the actual intermediate content is a realistic partial typing state
        let source = "";
        let target = r#"--- a/test.rs
+++ b/test.rs
@@ -1,2 +1,3 @@
 fn main() {
+    println!("hello world");
 }
"#;

        // Test with multiple seeds to see different split points
        let mut found_partial = false;
        for seed in 1..=50 {
            let (new_src, new_tgt, cursor) = imitate_human_edits(source, target, seed);

            if cursor.is_some() {
                let src_patch = Patch::parse_unified_diff(&new_src);
                let tgt_patch = Patch::parse_unified_diff(&new_tgt);

                // Find the added line in source
                for hunk in &src_patch.hunks {
                    for line in &hunk.lines {
                        if let PatchLine::Addition(content) = line {
                            // The partial line should be a prefix of the full line
                            let full_line = "    println!(\"hello world\");";
                            if content != full_line && full_line.starts_with(content) {
                                found_partial = true;

                                // Verify target has the partial as deletion
                                let mut has_deletion = false;
                                for tgt_hunk in &tgt_patch.hunks {
                                    for tgt_line in &tgt_hunk.lines {
                                        if let PatchLine::Deletion(del_content) = tgt_line {
                                            if del_content == content {
                                                has_deletion = true;
                                            }
                                        }
                                    }
                                }
                                assert!(
                                    has_deletion,
                                    "Target should have deletion of partial line"
                                );
                            }
                        }
                    }
                }
            }
        }

        assert!(
            found_partial,
            "At least one seed should produce a partial intermediate state"
        );
    }

    #[test]
    fn test_imitate_human_edits_inserts_after_last_source_edit() {
        // Regression test: intermediate content should appear after the last edit
        // in the source patch, not at the position of the first target edit.
        // This ensures the diff output correctly imitates human typing order.
        //
        // The bug was: when source has edits and target has a pure insertion,
        // the intermediate content was inserted at tgt_edit_loc.line_index_within_hunk
        // (position of first target edit) instead of after the last source edit.
        //
        // Source patch has edits at lines 1-4, target has a new edit at line 10
        // (different location to avoid the "same line" early return)
        let source = r#"--- a/test.py
+++ b/test.py
@@ -1,4 +1,5 @@
+import foo
 import bar
-import old
 import baz
+import qux
"#;
        // Target has a pure insertion at a different line (line 10, not overlapping with source)
        let target = r#"--- a/test.py
+++ b/test.py
@@ -10,3 +10,4 @@
 def main():
+    print("hello world")
     pass
"#;

        // Use a seed that produces a partial result
        let (new_src, _new_tgt, cursor) = imitate_human_edits(source, target, 42);

        // The function should produce a modified patch
        assert!(cursor.is_some(), "Should produce intermediate state");

        let src_patch = Patch::parse_unified_diff(&new_src);
        let all_additions: Vec<_> = src_patch
            .hunks
            .iter()
            .flat_map(|h| h.lines.iter())
            .filter_map(|l| match l {
                PatchLine::Addition(s) => Some(s.as_str()),
                _ => None,
            })
            .collect();

        // The intermediate content (partial 'print("hello world")') should be
        // the LAST addition, appearing after "+import qux" (the last source edit)
        let last_addition = all_additions.last().expect("Should have additions");
        assert!(
            last_addition.trim_start().starts_with("pr"),
            "Intermediate content should be the last addition (partial 'print'), but last was: {:?}",
            last_addition
        );

        // Verify the original source edits are still in order before the intermediate.
        // Unwrap the positions up front so the ordering check below compares plain
        // usizes rather than relying on Option's PartialOrd (None < Some(_)).
        let foo_pos = all_additions
            .iter()
            .position(|s| *s == "import foo")
            .expect("Should have 'import foo'");
        let qux_pos = all_additions
            .iter()
            .position(|s| *s == "import qux")
            .expect("Should have 'import qux'");
        let intermediate_pos = all_additions
            .iter()
            .position(|s| s.trim_start().starts_with("pr"))
            .expect("Should have intermediate content");

        assert!(
            foo_pos < qux_pos && qux_pos < intermediate_pos,
            "Order should be: foo < qux < intermediate. Got foo={:?}, qux={:?}, intermediate={:?}",
            foo_pos,
            qux_pos,
            intermediate_pos
        );
    }

    #[test]
    fn test_cursor_excerpt_with_multibyte_utf8() {
        // Test that cursor excerpt handles multi-byte UTF-8 characters correctly
        // The Chinese character '第' is 3 bytes (0..3)
        let cursor = CursorPosition {
            file: "test.md".to_string(),
            line: 1,
            column: 1, // Byte index 1 is inside '第' (bytes 0..3)
        };

        let source_patch = r#"--- a/test.md
+++ b/test.md
@@ -1,1 +1,1 @@
+第 14 章 Flask 工作原理与机制解析**
"#;

        let target_patch = "";

        // This should not panic even though column=1 is not a char boundary
        let result = get_cursor_excerpt(&cursor, source_patch, target_patch);

        // The function should handle the invalid byte index gracefully
        if let Some(excerpt) = result {
            assert!(
                excerpt.contains("<|user_cursor|>"),
                "Cursor excerpt should contain marker"
            );
            // The marker should be placed at a valid character boundary
            // (either at the start or after '第')
        }
    }
}