1//! `ep split-commit` implementation.
2//!
3//! This command generates a single evaluation example JSON object from a
4//! chronologically-ordered unified diff (a "commit").
5//!
6//! TODO: Port Python code to generate chronologically-ordered commits
7use crate::FailedHandling;
8use crate::reorder_patch::{Patch, PatchLine, extract_edits, locate_edited_line};
9use crate::word_diff::tokenize;
10
/// Find the largest valid UTF-8 char boundary at or before `index` in `s`.
///
/// Mirrors the behavior of the (unstable) `str::floor_char_boundary`:
/// out-of-range indices clamp to `s.len()`, and index 0 is always a boundary.
fn floor_char_boundary(s: &str, index: usize) -> usize {
    let mut boundary = index.min(s.len());
    // Walk backwards until we hit a boundary; 0 is always one, so this ends.
    while !s.is_char_boundary(boundary) {
        boundary -= 1;
    }
    boundary
}
25use anyhow::{Context as _, Result};
26use clap::Args;
27use edit_prediction::example_spec::ExampleSpec;
28use rand::Rng;
29use rand::SeedableRng;
30use serde::{Deserialize, Serialize};
31use similar::{DiffTag, TextDiff};
32use std::collections::BTreeSet;
33use std::fs;
34use std::io::{self, Write};
35use std::path::Path;
36use std::path::PathBuf;
37
/// `ep split-commit` CLI args.
///
/// Field doc comments double as `clap` help text, so they are kept unchanged;
/// extra detail lives in regular `//` comments that clap ignores.
#[derive(Debug, Args, Clone)]
pub struct SplitCommitArgs {
    // Parsed by `parse_split_point`: a value containing '.' is treated as a
    // fraction of the total edit count, anything else as an absolute index.
    /// Split point (float 0.0-1.0 for fraction, or integer for index)
    #[arg(long, short = 's')]
    pub split_point: Option<String>,

    // When absent in multi-sample mode, a random base seed is drawn per input.
    /// Random seed for reproducibility
    #[arg(long)]
    pub seed: Option<u64>,

    /// Pretty-print JSON output
    #[arg(long, short = 'p')]
    pub pretty: bool,

    // When set, `split_point` is ignored and each sample uses a seed derived
    // from the base seed plus the sample index; duplicate outputs are dropped.
    /// Number of samples to generate per commit (samples random split points)
    #[arg(long, short = 'n')]
    pub num_samples: Option<usize>,
}
57
/// Input format for annotated commits (JSON Lines).
///
/// One JSON object per input line. Only `repo_url`, `commit_sha`, and
/// `reordered_commit` are read by this module (hence `allow(dead_code)`);
/// the remaining fields are carried by the upstream annotation pipeline.
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)]
pub struct AnnotatedCommit {
    /// Repository path (e.g., "repos/zed")
    pub repo: String,
    /// Repository URL (e.g., "https://github.com/zed-industries/zed")
    pub repo_url: String,
    /// Commit SHA
    pub commit_sha: String,
    /// Chronologically reordered commit diff
    pub reordered_commit: String,
    /// Original commit diff
    pub original_commit: String,
    /// Whether diff stats match between original and reordered
    pub diff_stats_match: bool,
}
75
/// Cursor position in a file.
///
/// Displayed as `file:line:column`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CursorPosition {
    /// Path of the file the cursor is in, as it appears in the patch.
    pub file: String,
    /// Line number taken from patch line numbers.
    // NOTE(review): appears to be 1-based (unified-diff numbering) — confirm.
    pub line: usize,
    /// Column within the line; derived from byte offsets in this module.
    pub column: usize,
}
83
84impl std::fmt::Display for CursorPosition {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 write!(f, "{}:{}:{}", self.file, self.line, self.column)
87 }
88}
89
/// Represents a split commit with source and target patches.
#[derive(Debug, Clone)]
pub struct SplitCommit {
    /// Unified diff of the edits made so far (becomes the example's edit history).
    pub source_patch: String,
    /// Unified diff of the remaining edits (becomes the expected patch).
    pub target_patch: String,
}
96
/// Split point specification for evaluation generation.
///
/// Either form is clamped to the commit's total edit count when applied.
#[derive(Debug, Clone)]
pub enum SplitPoint {
    /// Fraction of total edits (0.0 to 1.0)
    Fraction(f64),
    /// Absolute index
    Index(usize),
}
105
106fn parse_split_point(value: &str) -> Option<SplitPoint> {
107 if value.contains('.') {
108 value.parse::<f64>().ok().map(SplitPoint::Fraction)
109 } else {
110 value.parse::<usize>().ok().map(SplitPoint::Index)
111 }
112}
113
/// Entry point for the `ep split-commit` subcommand.
///
/// This runs synchronously and outputs JSON Lines (one output per input line).
///
/// Reads [`AnnotatedCommit`] records as JSON Lines from `inputs` (the path
/// `-`, or an empty list, means stdin), generates one evaluation example per
/// record — or up to `--num-samples` deduplicated examples with derived seeds —
/// and writes the resulting JSON Lines to `output_path` (or stdout).
///
/// # Errors
/// Fails on unreadable input, malformed JSON, serialization failure, or — when
/// `failed` is `FailedHandling::Keep` — on the first example-generation error.
pub fn run_split_commit(
    args: &SplitCommitArgs,
    inputs: &[PathBuf],
    output_path: Option<&PathBuf>,
    failed: FailedHandling,
) -> Result<()> {
    use std::collections::HashSet;
    use std::io::BufRead;

    // Default to reading from stdin when no input paths were given.
    let stdin_path = PathBuf::from("-");
    let inputs = if inputs.is_empty() {
        std::slice::from_ref(&stdin_path)
    } else {
        inputs
    };

    let split_point = args.split_point.as_deref().and_then(parse_split_point);
    let mut output_lines = Vec::new();

    for input_path in inputs {
        // "-" means stdin; anything else is opened as a file.
        let input: Box<dyn BufRead> = if input_path.as_os_str() == "-" {
            Box::new(io::BufReader::new(io::stdin()))
        } else {
            let file = fs::File::open(input_path)
                .with_context(|| format!("failed to open input file {}", input_path.display()))?;
            Box::new(io::BufReader::new(file))
        };

        for (line_num, line_result) in input.lines().enumerate() {
            let line =
                line_result.with_context(|| format!("failed to read line {}", line_num + 1))?;

            // Blank lines are tolerated and skipped.
            if line.trim().is_empty() {
                continue;
            }

            let annotated: AnnotatedCommit = serde_json::from_str(&line)
                .with_context(|| format!("failed to parse JSON at line {}", line_num + 1))?;

            // Generate multiple samples if num_samples is set
            if let Some(num_samples) = args.num_samples {
                // Different seeds can land on the same split point and produce
                // identical JSON, so dedupe by serialized output.
                let mut seen_samples: HashSet<String> = HashSet::new();
                let base_seed = args.seed.unwrap_or_else(|| rand::random());

                for sample_idx in 0..num_samples {
                    // Distinct, reproducible seed per sample.
                    let sample_seed = base_seed.wrapping_add(sample_idx as u64);

                    let case = match generate_evaluation_example_from_ordered_commit(
                        &annotated.reordered_commit,
                        &annotated.repo_url,
                        &annotated.commit_sha,
                        None, // Use random split point for multi-sample mode
                        Some(sample_seed),
                        Some(sample_idx),
                    ) {
                        Ok(case) => case,
                        Err(e) => {
                            let err_msg = format!(
                                "failed to generate evaluation example for commit {} at line {} (sample {}): {}",
                                annotated.commit_sha,
                                line_num + 1,
                                sample_idx,
                                e
                            );
                            match failed {
                                FailedHandling::Skip | FailedHandling::SkipNoFiles => {
                                    eprintln!("{}", err_msg);
                                    continue;
                                }
                                FailedHandling::Keep => {
                                    anyhow::bail!(err_msg);
                                }
                            }
                        }
                    };

                    let json = if args.pretty {
                        serde_json::to_string_pretty(&case)
                    } else {
                        serde_json::to_string(&case)
                    }
                    .context("failed to serialize evaluation case as JSON")?;

                    // Only add unique samples (different split points may produce same result)
                    if seen_samples.insert(json.clone()) {
                        output_lines.push(json);
                    }
                }
            } else {
                // Single-sample mode: honor the explicit split point and seed.
                let case = match generate_evaluation_example_from_ordered_commit(
                    &annotated.reordered_commit,
                    &annotated.repo_url,
                    &annotated.commit_sha,
                    split_point.clone(),
                    args.seed,
                    None,
                ) {
                    Ok(case) => case,
                    Err(e) => {
                        let err_msg = format!(
                            "failed to generate evaluation example for commit {} at line {}: {}",
                            annotated.commit_sha,
                            line_num + 1,
                            e
                        );
                        match failed {
                            FailedHandling::Skip | FailedHandling::SkipNoFiles => {
                                eprintln!("{}", err_msg);
                                continue;
                            }
                            FailedHandling::Keep => {
                                anyhow::bail!(err_msg);
                            }
                        }
                    }
                };

                let json = if args.pretty {
                    serde_json::to_string_pretty(&case)
                } else {
                    serde_json::to_string(&case)
                }
                .context("failed to serialize evaluation case as JSON")?;

                output_lines.push(json);
            }
        }
    }

    // JSON Lines output: newline-terminated unless there is nothing to emit.
    let output_content = output_lines.join("\n") + if output_lines.is_empty() { "" } else { "\n" };

    if let Some(path) = output_path {
        fs::write(path, &output_content)
            .with_context(|| format!("failed to write output to {}", path.display()))?;
    } else {
        io::stdout()
            .write_all(output_content.as_bytes())
            .context("failed to write to stdout")?;
    }

    Ok(())
}
259
/// Main function to generate an evaluation example from an ordered commit.
///
/// # Arguments
/// * `commit` - Chronologically ordered unified diff of the commit
/// * `repository_url` - URL of the repository
/// * `commit_hash` - Hash of the commit
/// * `split_point` - Point at which the commit will be split (None for random)
/// * `seed` - Optional seed for randomness
/// * `sample_num` - Optional sample number for generating unique names
///
/// # Errors
/// Fails when the commit contains no edits, or when no cursor position or
/// cursor excerpt can be derived from the split patches.
pub fn generate_evaluation_example_from_ordered_commit(
    commit: &str,
    repository_url: &str,
    commit_hash: &str,
    split_point: Option<SplitPoint>,
    seed: Option<u64>,
    sample_num: Option<usize>,
) -> Result<ExampleSpec> {
    // Seeded RNG for reproducibility; unseeded runs use the thread RNG.
    let mut rng: Box<dyn rand::RngCore> = match seed {
        Some(seed) => Box::new(rand::rngs::StdRng::seed_from_u64(seed)),
        None => Box::new(rand::rngs::ThreadRng::default()),
    };

    // Parse and normalize the commit
    let mut patch = Patch::parse_unified_diff(commit);

    // Filter header to only keep lines starting with "//"
    // ("//"-prefixed lines serve as group headers in reordered commits).
    let header_lines: Vec<&str> = patch
        .header
        .lines()
        .filter(|line| line.starts_with("//"))
        .collect();
    patch.header = if header_lines.is_empty() {
        String::new()
    } else {
        header_lines.join("\n") + "\n"
    };
    let commit_normalized = patch.to_string();

    // Compute the split point: total edits = added + removed lines.
    let stats = patch.stats();
    let num_edits = stats.added + stats.removed;

    anyhow::ensure!(num_edits != 0, "no edits found in commit");

    // Random splits never pick 0 (which would leave the source empty);
    // explicit split points are clamped to the edit count.
    let split = match split_point {
        None => rng.random_range(1..=num_edits),
        Some(SplitPoint::Fraction(f)) => {
            let v = (f * num_edits as f64).floor() as usize;
            v.min(num_edits)
        }
        Some(SplitPoint::Index(i)) => i.min(num_edits),
    };

    // Split the commit into source and target patches
    let (prefix, suffix) = split_ordered_commit(&commit_normalized, split);

    let mut split_commit = SplitCommit {
        source_patch: prefix,
        target_patch: suffix,
    };

    // Imitate human edits (may rewrite both patches and supply a cursor).
    let human_edit_seed = rng.random_range(1..=10000u64);
    let (src_patch, tgt_patch, cursor_opt) = imitate_human_edits(
        &split_commit.source_patch,
        &split_commit.target_patch,
        human_edit_seed,
    );
    split_commit.source_patch = src_patch;
    split_commit.target_patch = tgt_patch;

    // Sample cursor position (fall back when imitation produced no cursor).
    let cursor = match cursor_opt {
        Some(c) => c,
        None => sample_cursor_position(&patch, &split_commit)
            .context("failed to sample cursor position")?,
    };

    // Get cursor excerpt
    let cursor_excerpt = get_cursor_excerpt(
        &cursor,
        &split_commit.source_patch,
        &split_commit.target_patch,
    )
    .context("failed to generate cursor excerpt")?;

    // Handle edge case where split_point == 0
    // NOTE(review): this runs after the cursor excerpt was computed from the
    // still-populated target patch — confirm that ordering is intentional.
    if split == 0 {
        split_commit.target_patch = String::new();
    }

    // Example name: "<repo>-<short sha>" plus the sample index when present.
    let repo_name = repository_url
        .trim_end_matches('/')
        .rsplit('/')
        .next()
        .unwrap_or("unknown");
    let short_sha = &commit_hash[..commit_hash.len().min(8)];
    let name = match sample_num {
        Some(n) => format!("{}-{}-{}", repo_name, short_sha, n),
        None => format!("{}-{}", repo_name, short_sha),
    };

    Ok(ExampleSpec {
        name,
        repository_url: repository_url.to_string(),
        // Revision is the parent of the commit (git "~1" suffix).
        revision: format!("{}~1", commit_hash),
        edit_history: split_commit.source_patch.clone(),
        cursor_path: Path::new(&cursor.file).into(),
        cursor_position: cursor_excerpt,
        expected_patches: vec![split_commit.target_patch],
        tags: vec![],
        reasoning: None,
        uncommitted_diff: String::new(),
        rejected_patch: None,

        telemetry: None,
        human_feedback: Vec::new(),
        rating: None,
    })
}
380
381/// Split an ordered commit into source and target commits.
382///
383/// # Arguments
384/// * `commit` - Ordered commit string
385/// * `split_pos` - Position to split the commit (number of edited lines)
386///
387/// # Returns
388/// A tuple of (source_diff, target_diff)
389pub fn split_ordered_commit(commit: &str, split_pos: usize) -> (String, String) {
390 let patch = Patch::parse_unified_diff(commit);
391 let source_edits: BTreeSet<usize> = (0..split_pos).collect();
392 let (source, target) = extract_edits(&patch, &source_edits);
393
394 let mut source_str = source.to_string();
395 let target_str = target.to_string();
396
397 // Strip last group header from the source (lines starting with "//" at the end)
398 let source_lines: Vec<&str> = source_str.lines().collect();
399 let mut end_idx = source_lines.len();
400 for i in (0..source_lines.len()).rev() {
401 if source_lines[i].starts_with("//") {
402 end_idx = i;
403 } else {
404 break;
405 }
406 }
407 if end_idx < source_lines.len() {
408 source_str = source_lines[..end_idx].join("\n");
409 if !source_str.is_empty() {
410 source_str.push('\n');
411 }
412 }
413
414 (source_str, target_str)
415}
416
/// Calculate the weight for a split position based on the character at that position.
///
/// Higher weights indicate more natural pause points (e.g., after punctuation,
/// at identifier boundaries). Lower weights indicate less natural points
/// (e.g., mid-identifier).
///
/// `pos` is a byte offset into `text` (callers pass `text.len()` as they
/// accumulate typed characters). Positions that are 0, out of range, or not a
/// UTF-8 char boundary get the minimum weight of 1.
fn position_weight(text: &str, pos: usize) -> u32 {
    if pos == 0 || pos > text.len() || !text.is_char_boundary(pos) {
        return 1;
    }

    // Decode the character just before this position (what we just "typed")
    // and the one right after it directly from the byte offset. The previous
    // implementation indexed a `Vec<char>` with this byte offset, which mixed
    // byte and char indices and mis-weighted any non-ASCII text.
    let prev_char = match text[..pos].chars().next_back() {
        Some(c) => c,
        None => return 1,
    };
    let next_char = text[pos..].chars().next();

    // High weight: natural pause points (end of statement/argument, opening brackets)
    if matches!(prev_char, ',' | ';' | ':' | '(' | '[' | '{') {
        return 10;
    }

    // High weight: closing brackets (finished a group)
    if matches!(prev_char, ')' | ']' | '}') {
        return 8;
    }

    // Medium weight: operators and method chains
    if matches!(
        prev_char,
        '.' | '+' | '-' | '*' | '/' | '=' | '<' | '>' | '&' | '|' | '!'
    ) {
        return 5;
    }

    // Check if we're at the end of an identifier (word char followed by non-word char)
    let is_prev_word_char = prev_char.is_alphanumeric() || prev_char == '_';
    let is_next_word_char = next_char.map_or(false, |c| c.is_alphanumeric() || c == '_');

    if is_prev_word_char && !is_next_word_char {
        // End of identifier - high weight
        return 8;
    }

    // Whitespace is a natural pause
    if prev_char.is_whitespace() {
        return 6;
    }

    // Mid-identifier: low weight (rare autocomplete scenarios)
    if is_prev_word_char && is_next_word_char {
        return 1;
    }

    // Default medium-low weight
    3
}
476
/// Select a weighted random index from a list of weights.
///
/// Returns an index based on the weights, using the provided seed for
/// deterministic selection. An empty slice yields index 0, and all-zero
/// weights fall back to a uniform `seed % len` choice.
fn weighted_select(weights: &[u32], seed: u64) -> usize {
    if weights.is_empty() {
        return 0;
    }

    let total_weight: u64 = weights.iter().map(|&w| u64::from(w)).sum();
    if total_weight == 0 {
        // Fallback to uniform selection if all weights are zero
        return seed as usize % weights.len();
    }

    // Walk the buckets, shrinking the target until it falls inside one.
    let mut remaining = seed % total_weight;
    for (idx, &weight) in weights.iter().enumerate() {
        let weight = u64::from(weight);
        if remaining < weight {
            return idx;
        }
        remaining -= weight;
    }

    // Unreachable in practice: `remaining < total_weight` guarantees a hit.
    weights.len() - 1
}
506
507/// Calculate similarity ratio between two strings (0-100).
508fn fuzzy_ratio(s1: &str, s2: &str) -> u32 {
509 if s1.is_empty() && s2.is_empty() {
510 return 100;
511 }
512 if s1.is_empty() || s2.is_empty() {
513 return 0;
514 }
515
516 let diff = TextDiff::from_chars(s1, s2);
517 let matching: usize = diff
518 .ops()
519 .iter()
520 .filter_map(|op| {
521 if matches!(op.tag(), DiffTag::Equal) {
522 Some(op.new_range().len())
523 } else {
524 None
525 }
526 })
527 .sum();
528
529 let total = s1.len() + s2.len();
530 ((2 * matching * 100) / total) as u32
531}
532
533/// Imitate human edits by introducing partial line edits.
534///
535/// This function simulates how a human might incrementally type code,
536/// rather than making complete line replacements.
537pub fn imitate_human_edits(
538 source_patch: &str,
539 target_patch: &str,
540 seed: u64,
541) -> (String, String, Option<CursorPosition>) {
542 let no_change = (source_patch.to_string(), target_patch.to_string(), None);
543
544 let src_patch = Patch::parse_unified_diff(source_patch);
545 let tgt_patch = Patch::parse_unified_diff(target_patch);
546
547 if tgt_patch.hunks.is_empty() {
548 return no_change;
549 }
550
551 // Try to locate the first edit in target
552 let tgt_edit_loc = match locate_edited_line(&tgt_patch, 0) {
553 Some(loc) => loc,
554 None => return no_change,
555 };
556
557 let tgt_is_addition = matches!(tgt_edit_loc.patch_line, PatchLine::Addition(_));
558 if !tgt_is_addition {
559 return no_change;
560 }
561
562 let tgt_line = match &tgt_edit_loc.patch_line {
563 PatchLine::Addition(s) => s.clone(),
564 _ => return no_change,
565 };
566
567 // Try to locate the last edit in source
568 let src_edit_loc = locate_edited_line(&src_patch, -1);
569
570 // Check if source has ANY edit at the same line as target's first edit
571 // We need to iterate through all edits to check this
572 let src_has_edit_at_target_line = {
573 let mut found = false;
574 let mut idx = 0isize;
575 while let Some(loc) = locate_edited_line(&src_patch, idx) {
576 if loc.filename == tgt_edit_loc.filename
577 && loc.target_line_number == tgt_edit_loc.target_line_number
578 {
579 found = true;
580 break;
581 }
582 idx += 1;
583 }
584 found
585 };
586
587 // Check if this is a replacement (deletion followed by insertion on the same line)
588 // or a pure insertion (no corresponding deletion in source)
589 let is_replacement = src_edit_loc.as_ref().map_or(false, |loc| {
590 matches!(loc.patch_line, PatchLine::Deletion(_))
591 && loc.filename == tgt_edit_loc.filename
592 && loc.target_line_number == tgt_edit_loc.target_line_number
593 });
594
595 // If source has an edit at the same line but it's not a replacement (i.e., it's an addition),
596 // we shouldn't process this as a pure insertion either
597 if !is_replacement && src_has_edit_at_target_line {
598 return no_change;
599 }
600
601 let src_line = if is_replacement {
602 match &src_edit_loc.as_ref().unwrap().patch_line {
603 PatchLine::Deletion(s) => s.clone(),
604 _ => return no_change,
605 }
606 } else {
607 // Pure insertion: source line is empty
608 String::new()
609 };
610
611 // Don't process if source and target are the same
612 if src_line == tgt_line {
613 return no_change;
614 }
615
616 // Tokenize both lines
617 let src_tokens = tokenize(&src_line);
618 let tgt_tokens = tokenize(&tgt_line);
619
620 // Use similar to get diff operations
621 let diff = TextDiff::from_slices(&src_tokens, &tgt_tokens);
622
623 // Build weights for each possible split position
624 let mut position_weights: Vec<u32> = Vec::new();
625
626 // Simulate the edit process to collect weights for all possible split positions
627 {
628 let mut current_text = String::new();
629
630 for op in diff.ops() {
631 match op.tag() {
632 DiffTag::Equal => {
633 for i in op.old_range() {
634 current_text.push_str(src_tokens[i]);
635 }
636 }
637 DiffTag::Replace => {
638 let ins: String = op.new_range().map(|i| tgt_tokens[i]).collect();
639 let del: String = op.old_range().map(|i| src_tokens[i]).collect();
640
641 // For insertion part
642 for ch in ins.chars() {
643 current_text.push(ch);
644 let weight = position_weight(¤t_text, current_text.len());
645 position_weights.push(weight);
646 }
647
648 // For deletion part (we're "untyping" from source)
649 for _ in del.chars() {
650 // Weight deletions lower as they represent removing text
651 position_weights.push(2);
652 }
653 }
654 DiffTag::Insert => {
655 let ins: String = op.new_range().map(|i| tgt_tokens[i]).collect();
656 for ch in ins.chars() {
657 current_text.push(ch);
658 let weight = position_weight(¤t_text, current_text.len());
659 position_weights.push(weight);
660 }
661 }
662 DiffTag::Delete => {
663 let del: String = op.old_range().map(|i| src_tokens[i]).collect();
664 for _ in del.chars() {
665 // Weight deletions lower
666 position_weights.push(2);
667 }
668 }
669 }
670 }
671 }
672
673 // Use weighted selection to choose split index
674 if position_weights.is_empty() {
675 return no_change;
676 }
677 let split_index = weighted_select(&position_weights, seed);
678
679 let mut edit_index = 0usize;
680 let mut new_src = String::new();
681 let mut split_found = false;
682 let mut last_old_end = 0usize;
683
684 for op in diff.ops() {
685 match op.tag() {
686 DiffTag::Equal => {
687 for i in op.old_range() {
688 new_src.push_str(src_tokens[i]);
689 }
690 last_old_end = op.old_range().end;
691 }
692 DiffTag::Replace => {
693 // Handle replace as delete + insert
694 let del: String = op.old_range().map(|i| src_tokens[i]).collect();
695 let ins: String = op.new_range().map(|i| tgt_tokens[i]).collect();
696 let repl_len = del.len() + ins.len();
697 if edit_index + repl_len >= split_index {
698 // Split within this replace operation
699 let offset = split_index - edit_index;
700 if offset < ins.len() {
701 let safe_offset = floor_char_boundary(&ins, offset);
702 new_src.push_str(&ins[..safe_offset]);
703 } else {
704 new_src.push_str(&ins);
705 let del_offset = offset - ins.len();
706 let safe_del_offset = floor_char_boundary(&del, del_offset.min(del.len()));
707 new_src.push_str(&del[..safe_del_offset]);
708 }
709 split_found = true;
710 last_old_end = op.old_range().end;
711 break;
712 } else {
713 edit_index += repl_len;
714 new_src.push_str(&ins);
715 last_old_end = op.old_range().end;
716 }
717 }
718 DiffTag::Insert => {
719 let repl: String = op.new_range().map(|i| tgt_tokens[i]).collect();
720 if edit_index + repl.len() >= split_index {
721 let offset = split_index - edit_index;
722 let safe_offset = floor_char_boundary(&repl, offset);
723 new_src.push_str(&repl[..safe_offset]);
724 split_found = true;
725 break;
726 } else {
727 edit_index += repl.len();
728 new_src.push_str(&repl);
729 }
730 }
731 DiffTag::Delete => {
732 let repl: String = op.old_range().map(|i| src_tokens[i]).collect();
733 if edit_index + repl.len() >= split_index {
734 let offset = split_index - edit_index;
735 let safe_offset = floor_char_boundary(&repl, offset);
736 new_src.push_str(&repl[..safe_offset]);
737 split_found = true;
738 last_old_end = op.old_range().start + safe_offset.min(op.old_range().len());
739 break;
740 } else {
741 edit_index += repl.len();
742 new_src.push_str(&repl);
743 last_old_end = op.old_range().end;
744 }
745 }
746 }
747 }
748
749 if !split_found {
750 return no_change;
751 }
752
753 // Calculate cursor position
754 let cursor = CursorPosition {
755 file: tgt_edit_loc.filename.clone(),
756 line: if is_replacement {
757 src_edit_loc.as_ref().unwrap().source_line_number
758 } else {
759 tgt_edit_loc.target_line_number
760 },
761 column: new_src.len() + 1,
762 };
763
764 // Add remainder of source if similar enough to target remainder
765 let remainder_src: String = (last_old_end..src_tokens.len())
766 .map(|i| src_tokens[i])
767 .collect();
768 let remainder_tgt: String = (last_old_end..tgt_tokens.len())
769 .filter_map(|i| tgt_tokens.get(i).copied())
770 .collect();
771
772 let ratio = fuzzy_ratio(&remainder_src, &remainder_tgt);
773 if ratio > 35 {
774 new_src.push_str(&remainder_src);
775 }
776
777 if new_src.trim().is_empty() {
778 return no_change;
779 }
780
781 if new_src == src_line {
782 return no_change;
783 }
784
785 // Build new source patch with the intermediate line
786 let mut new_src_patch = src_patch;
787 if is_replacement {
788 // For replacements, insert after the deletion line
789 let src_loc = src_edit_loc.as_ref().unwrap();
790 if let Some(hunk) = new_src_patch.hunks.get_mut(src_loc.hunk_index) {
791 hunk.lines.insert(
792 src_loc.line_index_within_hunk + 1,
793 PatchLine::Addition(new_src.clone()),
794 );
795 hunk.new_count += 1;
796 }
797 } else {
798 // For pure insertions, insert after the last edit in source patch
799 // This imitates human typing - the intermediate content is what the user is currently typing
800 let last_src_edit = locate_edited_line(&new_src_patch, -1);
801
802 if let Some(src_loc) = last_src_edit {
803 // Insert after the last edit in source
804 if let Some(hunk) = new_src_patch.hunks.get_mut(src_loc.hunk_index) {
805 hunk.lines.insert(
806 src_loc.line_index_within_hunk + 1,
807 PatchLine::Addition(new_src.clone()),
808 );
809 hunk.new_count += 1;
810 }
811 } else {
812 // Source patch is empty or has incompatible hunk structure, create a new hunk based on target
813 if let Some(tgt_hunk) = tgt_patch.hunks.get(tgt_edit_loc.hunk_index) {
814 let mut new_hunk = tgt_hunk.clone();
815 // Replace the full addition with the partial one
816 new_hunk.lines.clear();
817 for (i, line) in tgt_hunk.lines.iter().enumerate() {
818 if i == tgt_edit_loc.line_index_within_hunk {
819 new_hunk.lines.push(PatchLine::Addition(new_src.clone()));
820 } else {
821 match line {
822 PatchLine::Addition(_) => {
823 // Skip other additions from target
824 }
825 _ => new_hunk.lines.push(line.clone()),
826 }
827 }
828 }
829 new_hunk.new_count = new_hunk.old_count + 1;
830 new_src_patch.hunks.push(new_hunk);
831 // Copy header from target if source doesn't have one
832 if new_src_patch.header.is_empty() {
833 new_src_patch.header = tgt_patch.header.clone();
834 }
835 }
836 }
837 }
838
839 // Build new target patch with the intermediate line as deletion
840 let mut new_tgt_patch = tgt_patch;
841 if let Some(hunk) = new_tgt_patch.hunks.get_mut(tgt_edit_loc.hunk_index) {
842 hunk.lines.insert(
843 tgt_edit_loc.line_index_within_hunk,
844 PatchLine::Deletion(new_src),
845 );
846 hunk.old_count += 1;
847 }
848
849 (
850 new_src_patch.to_string(),
851 new_tgt_patch.to_string(),
852 Some(cursor),
853 )
854}
855
856/// Locate the end of the last edit in a patch.
857fn locate_end_of_last_edit(patch: &Patch) -> Option<CursorPosition> {
858 let loc = locate_edited_line(patch, -1)?;
859
860 let (line, col) = match &loc.patch_line {
861 PatchLine::Addition(content) => (loc.target_line_number, content.len()),
862 PatchLine::Deletion(_) => (loc.target_line_number, 1),
863 _ => return None,
864 };
865
866 Some(CursorPosition {
867 file: loc.filename,
868 line,
869 column: col,
870 })
871}
872
873/// Locate the beginning of the first edit in a patch.
874fn locate_beginning_of_first_edit(patch: &Patch) -> Option<CursorPosition> {
875 let loc = locate_edited_line(patch, 0)?;
876
877 let hunk = patch.hunks.get(loc.hunk_index)?;
878 let column = if loc.line_index_within_hunk > 0 {
879 if let Some(prev_line) = hunk.lines.get(loc.line_index_within_hunk - 1) {
880 let content = match prev_line {
881 PatchLine::Context(s) | PatchLine::Addition(s) | PatchLine::Deletion(s) => s,
882 _ => return None,
883 };
884 content.len().max(1) - 1
885 } else {
886 0
887 }
888 } else {
889 0
890 };
891
892 let line = loc.target_line_number.saturating_sub(1).max(1);
893
894 Some(CursorPosition {
895 file: loc.filename,
896 line,
897 column,
898 })
899}
900
901/// Sample cursor position according to the following rules:
902/// 1. 50% chance of cursor being at the end of the source patch
903/// 2. 50% chance of cursor being at the beginning of the target patch
904pub fn sample_cursor_position(patch: &Patch, split_commit: &SplitCommit) -> Option<CursorPosition> {
905 // Try end of history first
906 let src_patch = Patch::parse_unified_diff(&split_commit.source_patch);
907 if let Some(cursor) = locate_end_of_last_edit(&src_patch) {
908 return Some(cursor);
909 }
910
911 // Try beginning of target
912 let tgt_patch = Patch::parse_unified_diff(&split_commit.target_patch);
913 if let Some(cursor) = locate_beginning_of_first_edit(&tgt_patch) {
914 return Some(cursor);
915 }
916
917 // Fallback: use the original patch
918 locate_end_of_last_edit(patch)
919}
920
/// Get cursor excerpt from the patches.
///
/// This extracts the lines around the cursor position with a cursor marker.
///
/// Search order: the hunk containing the last edit of `source_patch` (when it
/// matches the cursor's file and line), then any hunk of `target_patch` for
/// the cursor's file, then any hunk of `source_patch` for that file. Returns
/// `None` when no candidate lines are found; the `<|user_cursor|>` marker is
/// inserted at the cursor's column, clamped to a valid UTF-8 boundary.
pub fn get_cursor_excerpt(
    cursor: &CursorPosition,
    source_patch: &str,
    target_patch: &str,
) -> Option<String> {
    let mut excerpt_lines: Vec<String> = Vec::new();
    // File line number of the first excerpt line; used to find the marker line.
    let mut excerpt_first_line: usize = 0;

    // Search in the last hunk of source patch
    let src = Patch::parse_unified_diff(source_patch);
    if let Some(loc) = locate_edited_line(&src, -1) {
        if loc.filename == cursor.file && loc.target_line_number == cursor.line {
            if let Some(hunk) = src.hunks.get(loc.hunk_index) {
                // Post-edit view: additions + context are the lines present
                // after the source edits were applied.
                excerpt_first_line = hunk.new_start as usize;
                for line in &hunk.lines {
                    match line {
                        PatchLine::Addition(s) | PatchLine::Context(s) => {
                            excerpt_lines.push(s.clone());
                        }
                        _ => {}
                    }
                }
                // If hunk only has deletions (file deletion), include deletion lines
                if excerpt_lines.is_empty() {
                    excerpt_first_line = hunk.old_start as usize;
                    for line in &hunk.lines {
                        match line {
                            PatchLine::Deletion(s) => {
                                excerpt_lines.push(s.clone());
                            }
                            _ => {}
                        }
                    }
                }
            }
        }
    }

    // Search in target patch if not found
    if excerpt_lines.is_empty() {
        let tgt = Patch::parse_unified_diff(target_patch);
        // Search all hunks for the cursor file, not just the first edit's hunk
        for hunk in &tgt.hunks {
            if hunk.filename == cursor.file {
                excerpt_first_line = hunk.new_start as usize;
                // First try to collect deletions and context (what exists before edits)
                // NOTE(review): deletions occupy no lines in the new file, so
                // numbering collected lines from `new_start` may drift when a
                // hunk mixes deletions with context — confirm intended.
                for line in &hunk.lines {
                    match line {
                        PatchLine::Deletion(s) | PatchLine::Context(s) => {
                            excerpt_lines.push(s.clone());
                        }
                        _ => {}
                    }
                }
                // If hunk only has additions (no deletions/context), include all lines
                // This handles cases like adding to an empty file or section
                if excerpt_lines.is_empty() {
                    for line in &hunk.lines {
                        match line {
                            PatchLine::Addition(s)
                            | PatchLine::Deletion(s)
                            | PatchLine::Context(s) => {
                                excerpt_lines.push(s.clone());
                            }
                            _ => {}
                        }
                    }
                }
                if !excerpt_lines.is_empty() {
                    break;
                }
            }
        }
    }

    // Also search source patch hunks if still not found (for fallback cursor case)
    if excerpt_lines.is_empty() {
        for hunk in &src.hunks {
            if hunk.filename == cursor.file {
                excerpt_first_line = hunk.new_start as usize;
                for line in &hunk.lines {
                    match line {
                        PatchLine::Addition(s) | PatchLine::Context(s) => {
                            excerpt_lines.push(s.clone());
                        }
                        _ => {}
                    }
                }
                // If hunk only has deletions, include deletion lines
                if excerpt_lines.is_empty() {
                    excerpt_first_line = hunk.old_start as usize;
                    for line in &hunk.lines {
                        match line {
                            PatchLine::Deletion(s) => {
                                excerpt_lines.push(s.clone());
                            }
                            _ => {}
                        }
                    }
                }
                if !excerpt_lines.is_empty() {
                    break;
                }
            }
        }
    }

    if excerpt_lines.is_empty() {
        return None;
    }

    // Add cursor marker
    for (i, line) in excerpt_lines.iter_mut().enumerate() {
        let line_num = excerpt_first_line + i;
        if line_num == cursor.line {
            let col = cursor.column.min(line.len());
            // Ensure we split at a valid UTF-8 character boundary
            let col = if line.is_char_boundary(col) {
                col
            } else {
                // Find the nearest valid character boundary
                (0..=col)
                    .rev()
                    .find(|&i| line.is_char_boundary(i))
                    .unwrap_or(0)
            };
            let (before, after) = line.split_at(col);
            *line = format!("{}<|user_cursor|>{}", before, after);
            break;
        }
    }

    Some(excerpt_lines.join("\n"))
}
1058
1059#[cfg(test)]
1060mod tests {
1061 use std::path::Path;
1062
1063 use edit_prediction::example_spec::ExampleSpec;
1064
1065 use super::*;
1066
    #[test]
    fn test_tokenize() {
        // `tokenize` keeps identifiers (letters, digits, '_') together and
        // emits whitespace runs and individual punctuation as their own tokens.
        let tokens = tokenize("hello world");
        assert_eq!(tokens, vec!["hello", " ", "world"]);

        let tokens = tokenize("foo_bar123 + baz");
        assert_eq!(tokens, vec!["foo_bar123", " ", "+", " ", "baz"]);

        let tokens = tokenize("print(\"hello\")");
        assert_eq!(tokens, vec!["print", "(", "\"", "hello", "\"", ")"]);

        let tokens = tokenize("hello_world");
        assert_eq!(tokens, vec!["hello_world"]);

        let tokens = tokenize("fn();");
        assert_eq!(tokens, vec!["fn", "(", ")", ";"]);
    }
1084
    #[test]
    fn test_fuzzy_ratio() {
        // Identical inputs (including both-empty) score 100; disjoint strings
        // score low, near-identical strings score high.
        assert_eq!(fuzzy_ratio("hello", "hello"), 100);
        assert_eq!(fuzzy_ratio("", ""), 100);
        assert!(fuzzy_ratio("hello", "world") < 50);
        assert!(fuzzy_ratio("hello world", "hello worl") > 80);
    }
1092
1093 #[test]
1094 fn test_split_ordered_commit() {
1095 let commit = r#"// First change
1096--- a/test.rs
1097+++ b/test.rs
1098@@ -1,3 +1,4 @@
1099 fn main() {
1100+ println!("hello");
1101+ println!("world");
1102 }
1103"#;
1104 let patch = Patch::parse_unified_diff(commit);
1105 let stats = patch.stats();
1106 assert_eq!(stats.added, 2);
1107
1108 let (source, target) = split_ordered_commit(commit, 1);
1109
1110 // Source should have 1 addition
1111 let src_patch = Patch::parse_unified_diff(&source);
1112 assert_eq!(src_patch.stats().added, 1);
1113
1114 // Target should have 1 addition
1115 let tgt_patch = Patch::parse_unified_diff(&target);
1116 assert_eq!(tgt_patch.stats().added, 1);
1117 }
1118
1119 #[test]
1120 fn test_split_ordered_commit_with_deletions() {
1121 let commit = r#"// Change
1122--- a/test.rs
1123+++ b/test.rs
1124@@ -1,3 +1,3 @@
1125 fn main() {
1126- println!("old");
1127+ println!("new");
1128 }
1129"#;
1130 let patch = Patch::parse_unified_diff(commit);
1131 let stats = patch.stats();
1132 assert_eq!(stats.added, 1);
1133 assert_eq!(stats.removed, 1);
1134
1135 // Split at position 1 (after the deletion)
1136 let (source, target) = split_ordered_commit(commit, 1);
1137
1138 let src_patch = Patch::parse_unified_diff(&source);
1139 let tgt_patch = Patch::parse_unified_diff(&target);
1140
1141 // Source should have the deletion
1142 assert_eq!(src_patch.stats().removed, 1);
1143 // Target should have the addition
1144 assert_eq!(tgt_patch.stats().added, 1);
1145 }
1146
1147 #[test]
1148 fn test_generate_evaluation_example() {
1149 let commit = r#"commit abc123
1150Author: Test <test@example.com>
1151Date: Mon Jan 1 00:00:00 2024
1152
1153 Test commit
1154
1155////////////////////////////////////////////////////////////////////////////////
1156// Add greeting
1157////////////////////////////////////////////////////////////////////////////////
1158--- a/test.rs
1159+++ b/test.rs
1160@@ -1,3 +1,5 @@
1161 fn main() {
1162+ println!("hello");
1163+ println!("world");
1164 }
1165"#;
1166
1167 let result = generate_evaluation_example_from_ordered_commit(
1168 commit,
1169 "https://github.com/test/repo",
1170 "abc123",
1171 Some(SplitPoint::Fraction(0.5)),
1172 Some(42),
1173 None,
1174 );
1175
1176 assert!(result.is_ok());
1177 let case = result.unwrap();
1178 assert_eq!(case.repository_url, "https://github.com/test/repo");
1179 assert_eq!(case.revision, "abc123~1");
1180 assert!(!case.edit_history.is_empty());
1181 }
1182
1183 #[test]
1184 fn test_generate_evaluation_example_reproducible() {
1185 let commit = r#"////////////////////////////////////////////////////////////////////////////////
1186// Add greeting
1187////////////////////////////////////////////////////////////////////////////////
1188--- a/test.rs
1189+++ b/test.rs
1190@@ -1,3 +1,5 @@
1191 fn main() {
1192+ println!("hello");
1193+ println!("world");
1194 }
1195"#;
1196
1197 // Run twice with the same seed
1198 let result1 = generate_evaluation_example_from_ordered_commit(
1199 commit,
1200 "https://github.com/test/repo",
1201 "abc123",
1202 Some(SplitPoint::Fraction(0.5)),
1203 Some(12345),
1204 None,
1205 )
1206 .unwrap();
1207
1208 let result2 = generate_evaluation_example_from_ordered_commit(
1209 commit,
1210 "https://github.com/test/repo",
1211 "abc123",
1212 Some(SplitPoint::Fraction(0.5)),
1213 Some(12345),
1214 None,
1215 )
1216 .unwrap();
1217
1218 // Results should be identical
1219 assert_eq!(result1.edit_history, result2.edit_history);
1220 assert_eq!(result1.expected_patches, result2.expected_patches);
1221 assert_eq!(result1.cursor_position, result2.cursor_position);
1222 }
1223
1224 #[test]
1225 fn test_cursor_position_display() {
1226 let cursor = CursorPosition {
1227 file: "src/main.rs".to_string(),
1228 line: 42,
1229 column: 10,
1230 };
1231 assert_eq!(cursor.to_string(), "src/main.rs:42:10");
1232 }
1233
1234 #[test]
1235 fn test_imitate_human_edits_no_change_when_no_replacement() {
1236 // Source and target patches that don't form a replacement pattern
1237 let source = r#"--- a/test.rs
1238+++ b/test.rs
1239@@ -1,3 +1,4 @@
1240 fn main() {
1241+ println!("hello");
1242 }
1243"#;
1244 let target = r#"--- a/test.rs
1245+++ b/test.rs
1246@@ -1,3 +1,4 @@
1247 fn main() {
1248+ println!("world");
1249 }
1250"#;
1251
1252 let (new_src, new_tgt, cursor) = imitate_human_edits(source, target, 42);
1253
1254 // Should return unchanged when not a replacement pattern
1255 assert_eq!(new_src, source);
1256 assert_eq!(new_tgt, target);
1257 assert!(cursor.is_none());
1258 }
1259
1260 #[test]
1261 fn test_split_point_fraction() {
1262 let commit = r#"// Change
1263--- a/test.rs
1264+++ b/test.rs
1265@@ -1,5 +1,10 @@
1266 fn main() {
1267+ line1();
1268+ line2();
1269+ line3();
1270+ line4();
1271+ line5();
1272 }
1273"#;
1274
1275 // Split at 20% should give first edit in source
1276 let result = generate_evaluation_example_from_ordered_commit(
1277 commit,
1278 "",
1279 "hash",
1280 Some(SplitPoint::Fraction(0.2)),
1281 Some(1),
1282 None,
1283 );
1284
1285 assert!(result.is_ok());
1286 let case = result.unwrap();
1287
1288 // Source should have some edits
1289 let src_patch = Patch::parse_unified_diff(&case.edit_history);
1290 assert!(src_patch.stats().added > 0);
1291 }
1292
1293 #[test]
1294 fn test_split_point_index() {
1295 let commit = r#"// Change
1296--- a/test.rs
1297+++ b/test.rs
1298@@ -1,5 +1,10 @@
1299 fn main() {
1300+ line1();
1301+ line2();
1302+ line3();
1303+ line4();
1304+ line5();
1305 }
1306"#;
1307
1308 // Split at index 2 should give first 2 edits in source
1309 // With pure insertion handling, source gets 2 original + 1 partial = 3 additions
1310 let result = generate_evaluation_example_from_ordered_commit(
1311 commit,
1312 "",
1313 "hash",
1314 Some(SplitPoint::Index(2)),
1315 Some(1),
1316 None,
1317 );
1318
1319 assert!(result.is_ok());
1320 let case = result.unwrap();
1321
1322 let src_patch = Patch::parse_unified_diff(&case.edit_history);
1323 // Pure insertion adds a partial line, so we expect 3 (2 original + 1 partial)
1324 assert_eq!(src_patch.stats().added, 3);
1325 }
1326
1327 #[test]
1328 fn test_cursor_excerpt_contains_marker() {
1329 let commit = r#"////////////////////////////////////////////////////////////////////////////////
1330// Add code
1331////////////////////////////////////////////////////////////////////////////////
1332--- a/test.rs
1333+++ b/test.rs
1334@@ -1,3 +1,5 @@
1335 fn main() {
1336+ println!("hello");
1337+ println!("world");
1338 }
1339"#;
1340
1341 let result = generate_evaluation_example_from_ordered_commit(
1342 commit,
1343 "",
1344 "hash",
1345 Some(SplitPoint::Fraction(0.5)),
1346 Some(42),
1347 None,
1348 )
1349 .unwrap();
1350
1351 // Cursor excerpt should contain the cursor marker
1352 assert!(
1353 result.cursor_position.contains("<|user_cursor|>"),
1354 "Cursor excerpt should contain marker: {}",
1355 result.cursor_position
1356 );
1357 }
1358
1359 #[test]
1360 fn test_evaluation_case_json_serialization() {
1361 let case = ExampleSpec {
1362 name: "test-abc123".to_string(),
1363 repository_url: "https://github.com/test/repo".to_string(),
1364 revision: "abc123~1".to_string(),
1365 edit_history: "patch1".to_string(),
1366 cursor_path: Path::new("file.rs").into(),
1367 cursor_position: "some code<|user_cursor|>".to_string(),
1368 expected_patches: vec!["patch".to_string()],
1369 tags: vec![],
1370 reasoning: None,
1371 uncommitted_diff: String::new(),
1372 rejected_patch: None,
1373
1374 telemetry: None,
1375 human_feedback: Vec::new(),
1376 rating: None,
1377 };
1378
1379 let json = serde_json::to_string(&case).unwrap();
1380 let deserialized: ExampleSpec = serde_json::from_str(&json).unwrap();
1381
1382 assert_eq!(case.repository_url, deserialized.repository_url);
1383 assert_eq!(case.revision, deserialized.revision);
1384 assert_eq!(case.cursor_position, deserialized.cursor_position);
1385 }
1386
1387 #[test]
1388 fn test_empty_commit_returns_error() {
1389 let commit = "";
1390
1391 let result = generate_evaluation_example_from_ordered_commit(
1392 commit,
1393 "",
1394 "hash",
1395 Some(SplitPoint::Fraction(0.5)),
1396 Some(1),
1397 None,
1398 );
1399
1400 assert!(result.is_err());
1401 }
1402
1403 #[test]
1404 fn test_header_filtering() {
1405 let commit = r#"commit abc123
1406Author: Test
1407Date: Today
1408
1409 Message
1410
1411diff --git a/test.rs b/test.rs
1412index 123..456 789
1413////////////////////////////////////////////////////////////////////////////////
1414// First group
1415////////////////////////////////////////////////////////////////////////////////
1416--- a/test.rs
1417+++ b/test.rs
1418@@ -1,3 +1,4 @@
1419 fn main() {
1420+ code();
1421 }
1422"#;
1423
1424 let result = generate_evaluation_example_from_ordered_commit(
1425 commit,
1426 "",
1427 "hash",
1428 Some(SplitPoint::Index(1)),
1429 Some(1),
1430 None,
1431 );
1432
1433 assert!(result.is_ok());
1434 let case = result.unwrap();
1435
1436 // The edit history should contain the group header (// lines)
1437 // but not the commit metadata
1438 assert!(!case.edit_history.contains("Author:"));
1439 assert!(!case.edit_history.contains("Date:"));
1440 }
1441
1442 #[test]
1443 fn test_position_weight() {
1444 // High weight positions (natural pause points)
1445 assert_eq!(position_weight("foo(", 4), 10); // After '('
1446 assert_eq!(position_weight("a, b", 2), 10); // After ','
1447 assert_eq!(position_weight("x;", 2), 10); // After ';'
1448 assert_eq!(position_weight("a: b", 2), 10); // After ':'
1449 assert_eq!(position_weight("[", 1), 10); // After '['
1450 assert_eq!(position_weight("{", 1), 10); // After '{'
1451
1452 // High weight for closing brackets
1453 assert_eq!(position_weight("foo)", 4), 8); // After ')'
1454 assert_eq!(position_weight("]", 1), 8); // After ']'
1455 assert_eq!(position_weight("}", 1), 8); // After '}'
1456
1457 // High weight at end of identifier
1458 assert_eq!(position_weight("foo ", 3), 8); // End of 'foo' before space
1459 assert_eq!(position_weight("bar(", 3), 8); // End of 'bar' before '('
1460
1461 // Medium weight for operators
1462 assert_eq!(position_weight("a + b", 3), 5); // After '+'
1463 assert_eq!(position_weight("x.", 2), 5); // After '.'
1464 assert_eq!(position_weight("a=b", 2), 5); // After '='
1465
1466 // Medium weight for whitespace
1467 assert_eq!(position_weight("a ", 2), 6); // After space
1468
1469 // Low weight mid-identifier
1470 assert_eq!(position_weight("foobar", 3), 1); // Mid-identifier 'foo|bar'
1471
1472 // Edge cases
1473 assert_eq!(position_weight("", 0), 1); // Empty string
1474 assert_eq!(position_weight("a", 0), 1); // Position 0
1475 }
1476
1477 #[test]
1478 fn test_weighted_select() {
1479 // Test that weighted selection returns correct indices
1480 let weights = vec![1, 10, 1];
1481
1482 // With total weight 12, seed 0 should select index 0
1483 // seed 0 % 12 = 0, cumulative: 1 at idx 0, so returns 0
1484 assert_eq!(weighted_select(&weights, 0), 0);
1485
1486 // seed 1 % 12 = 1, cumulative: 1 at idx 0 (1 < 1 is false), 11 at idx 1 (1 < 11 is true)
1487 assert_eq!(weighted_select(&weights, 1), 1);
1488
1489 // seed 10 % 12 = 10, cumulative: 1, 11 at idx 1 (10 < 11 is true)
1490 assert_eq!(weighted_select(&weights, 10), 1);
1491
1492 // seed 11 % 12 = 11, cumulative: 1, 11 at idx 1 (11 < 11 is false), 12 at idx 2 (11 < 12 is true)
1493 assert_eq!(weighted_select(&weights, 11), 2);
1494
1495 // Empty weights should return 0
1496 let empty: Vec<u32> = vec![];
1497 assert_eq!(weighted_select(&empty, 42), 0);
1498
1499 // Single weight should always return index 0
1500 let single = vec![10];
1501 assert_eq!(weighted_select(&single, 0), 0);
1502 assert_eq!(weighted_select(&single, 100), 0);
1503 }
1504
1505 #[test]
1506 fn test_weighted_split_prefers_natural_boundaries() {
1507 // Test that with different seeds, weighted selection tends to prefer
1508 // positions after punctuation over mid-identifier positions
1509 let text_with_punctuation = "foo(bar, baz)";
1510 let text_mid_identifier = "foobar";
1511
1512 // Position after '(' should have high weight
1513 let weight_after_paren = position_weight(text_with_punctuation, 4);
1514 // Position after ',' should have high weight
1515 let weight_after_comma = position_weight(text_with_punctuation, 8);
1516 // Position mid-identifier should have low weight
1517 let weight_mid_ident = position_weight(text_mid_identifier, 3);
1518
1519 assert!(
1520 weight_after_paren > weight_mid_ident,
1521 "After '(' ({}) should be weighted higher than mid-identifier ({})",
1522 weight_after_paren,
1523 weight_mid_ident
1524 );
1525 assert!(
1526 weight_after_comma > weight_mid_ident,
1527 "After ',' ({}) should be weighted higher than mid-identifier ({})",
1528 weight_after_comma,
1529 weight_mid_ident
1530 );
1531 }
1532
1533 #[test]
1534 fn test_imitate_human_edits_pure_insertion() {
1535 // Source patch is empty (no edits yet)
1536 // Target patch has a pure insertion (adding a new line)
1537 let source = r#"--- a/test.rs
1538+++ b/test.rs
1539@@ -1,2 +1,2 @@
1540 fn main() {
1541 }
1542"#;
1543 let target = r#"--- a/test.rs
1544+++ b/test.rs
1545@@ -1,2 +1,3 @@
1546 fn main() {
1547+ println!("debug");
1548 }
1549"#;
1550
1551 let (new_src, new_tgt, cursor) = imitate_human_edits(source, target, 42);
1552
1553 // Should have transformed the patches
1554 assert_ne!(
1555 new_src, source,
1556 "Source should be modified for pure insertion"
1557 );
1558 assert_ne!(
1559 new_tgt, target,
1560 "Target should be modified for pure insertion"
1561 );
1562 assert!(cursor.is_some(), "Cursor should be set");
1563
1564 // Source should now have a partial addition
1565 let src_patch = Patch::parse_unified_diff(&new_src);
1566 assert!(
1567 src_patch.stats().added > 0,
1568 "Source should have added lines"
1569 );
1570
1571 // Target should have both a deletion (of partial) and addition (of full)
1572 let tgt_patch = Patch::parse_unified_diff(&new_tgt);
1573 assert!(
1574 tgt_patch.stats().removed > 0,
1575 "Target should have removed lines (partial)"
1576 );
1577 assert!(
1578 tgt_patch.stats().added > 0,
1579 "Target should have added lines (full)"
1580 );
1581
1582 // The cursor should be in test.rs
1583 let cursor = cursor.unwrap();
1584 assert_eq!(cursor.file, "test.rs");
1585 }
1586
1587 #[test]
1588 fn test_imitate_human_edits_pure_insertion_empty_source() {
1589 // Source patch has no hunks at all
1590 let source = "";
1591 let target = r#"--- a/test.rs
1592+++ b/test.rs
1593@@ -1,2 +1,3 @@
1594 fn main() {
1595+ println!("hello");
1596 }
1597"#;
1598
1599 let (new_src, _new_tgt, cursor) = imitate_human_edits(source, target, 123);
1600
1601 // Should have created a source patch with partial insertion
1602 assert!(!new_src.is_empty(), "Source should not be empty");
1603 assert!(cursor.is_some(), "Cursor should be set");
1604
1605 let src_patch = Patch::parse_unified_diff(&new_src);
1606 assert!(
1607 src_patch.stats().added > 0,
1608 "Source should have added lines"
1609 );
1610 }
1611
1612 #[test]
1613 fn test_imitate_human_edits_pure_insertion_intermediate_content() {
1614 // Verify the actual intermediate content is a realistic partial typing state
1615 let source = "";
1616 let target = r#"--- a/test.rs
1617+++ b/test.rs
1618@@ -1,2 +1,3 @@
1619 fn main() {
1620+ println!("hello world");
1621 }
1622"#;
1623
1624 // Test with multiple seeds to see different split points
1625 let mut found_partial = false;
1626 for seed in 1..=50 {
1627 let (new_src, new_tgt, cursor) = imitate_human_edits(source, target, seed);
1628
1629 if cursor.is_some() {
1630 let src_patch = Patch::parse_unified_diff(&new_src);
1631 let tgt_patch = Patch::parse_unified_diff(&new_tgt);
1632
1633 // Find the added line in source
1634 for hunk in &src_patch.hunks {
1635 for line in &hunk.lines {
1636 if let PatchLine::Addition(content) = line {
1637 // The partial line should be a prefix of the full line
1638 let full_line = " println!(\"hello world\");";
1639 if content != full_line && full_line.starts_with(content) {
1640 found_partial = true;
1641
1642 // Verify target has the partial as deletion
1643 let mut has_deletion = false;
1644 for tgt_hunk in &tgt_patch.hunks {
1645 for tgt_line in &tgt_hunk.lines {
1646 if let PatchLine::Deletion(del_content) = tgt_line {
1647 if del_content == content {
1648 has_deletion = true;
1649 }
1650 }
1651 }
1652 }
1653 assert!(
1654 has_deletion,
1655 "Target should have deletion of partial line"
1656 );
1657 }
1658 }
1659 }
1660 }
1661 }
1662 }
1663
1664 assert!(
1665 found_partial,
1666 "At least one seed should produce a partial intermediate state"
1667 );
1668 }
1669
1670 #[test]
1671 fn test_imitate_human_edits_inserts_after_last_source_edit() {
1672 // Regression test: intermediate content should appear after the last edit
1673 // in the source patch, not at the position of the first target edit.
1674 // This ensures the diff output correctly imitates human typing order.
1675 //
1676 // The bug was: when source has edits and target has a pure insertion,
1677 // the intermediate content was inserted at tgt_edit_loc.line_index_within_hunk
1678 // (position of first target edit) instead of after the last source edit.
1679 //
1680 // Source patch has edits at lines 1-4, target has a new edit at line 10
1681 // (different location to avoid the "same line" early return)
1682 let source = r#"--- a/test.py
1683+++ b/test.py
1684@@ -1,4 +1,5 @@
1685+import foo
1686 import bar
1687-import old
1688 import baz
1689+import qux
1690"#;
1691 // Target has a pure insertion at a different line (line 10, not overlapping with source)
1692 let target = r#"--- a/test.py
1693+++ b/test.py
1694@@ -10,3 +10,4 @@
1695 def main():
1696+ print("hello world")
1697 pass
1698"#;
1699
1700 // Use a seed that produces a partial result
1701 let (new_src, _new_tgt, cursor) = imitate_human_edits(source, target, 42);
1702
1703 // The function should produce a modified patch
1704 assert!(cursor.is_some(), "Should produce intermediate state");
1705
1706 let src_patch = Patch::parse_unified_diff(&new_src);
1707 let all_additions: Vec<_> = src_patch
1708 .hunks
1709 .iter()
1710 .flat_map(|h| h.lines.iter())
1711 .filter_map(|l| match l {
1712 PatchLine::Addition(s) => Some(s.as_str()),
1713 _ => None,
1714 })
1715 .collect();
1716
1717 // The intermediate content (partial 'print("hello world")') should be
1718 // the LAST addition, appearing after "+import qux" (the last source edit)
1719 let last_addition = all_additions.last().expect("Should have additions");
1720 assert!(
1721 last_addition.trim_start().starts_with("pr"),
1722 "Intermediate content should be the last addition (partial 'print'), but last was: {:?}",
1723 last_addition
1724 );
1725
1726 // Verify the original source edits are still in order before the intermediate
1727 let foo_pos = all_additions.iter().position(|s| *s == "import foo");
1728 let qux_pos = all_additions.iter().position(|s| *s == "import qux");
1729 let intermediate_pos = all_additions
1730 .iter()
1731 .position(|s| s.trim_start().starts_with("pr"));
1732
1733 assert!(foo_pos.is_some(), "Should have 'import foo'");
1734 assert!(qux_pos.is_some(), "Should have 'import qux'");
1735 assert!(
1736 intermediate_pos.is_some(),
1737 "Should have intermediate content"
1738 );
1739
1740 assert!(
1741 foo_pos < qux_pos && qux_pos < intermediate_pos,
1742 "Order should be: foo < qux < intermediate. Got foo={:?}, qux={:?}, intermediate={:?}",
1743 foo_pos,
1744 qux_pos,
1745 intermediate_pos
1746 );
1747 }
1748
1749 #[test]
1750 fn test_cursor_excerpt_with_multibyte_utf8() {
1751 // Test that cursor excerpt handles multi-byte UTF-8 characters correctly
1752 // The Chinese character '第' is 3 bytes (0..3)
1753 let cursor = CursorPosition {
1754 file: "test.md".to_string(),
1755 line: 1,
1756 column: 1, // Byte index 1 is inside '第' (bytes 0..3)
1757 };
1758
1759 let source_patch = r#"--- a/test.md
1760+++ b/test.md
1761@@ -1,1 +1,1 @@
1762+第 14 章 Flask 工作原理与机制解析**
1763"#;
1764
1765 let target_patch = "";
1766
1767 // This should not panic even though column=1 is not a char boundary
1768 let result = get_cursor_excerpt(&cursor, source_patch, target_patch);
1769
1770 // The function should handle the invalid byte index gracefully
1771 if let Some(excerpt) = result {
1772 assert!(
1773 excerpt.contains("<|user_cursor|>"),
1774 "Cursor excerpt should contain marker"
1775 );
1776 // The marker should be placed at a valid character boundary
1777 // (either at the start or after '第')
1778 }
1779 }
1780}