1//! `ep split-commit` implementation.
2//!
3//! This command generates a single evaluation example JSON object from a
4//! chronologically-ordered unified diff (a "commit").
5//!
6//! TODO: Port Python code to generate chronologically-ordered commits
7use crate::reorder_patch::{Patch, PatchLine, extract_edits, locate_edited_line};
8
/// Find the largest valid UTF-8 char boundary at or before `index` in `s`.
///
/// Mirrors the (unstable) `str::floor_char_boundary`: an `index` past the end
/// clamps to `s.len()`; otherwise we walk backwards until we land on a
/// boundary. Byte offset 0 is always a boundary, so the walk terminates.
fn floor_char_boundary(s: &str, index: usize) -> usize {
    if index >= s.len() {
        return s.len();
    }
    // Step down from `index` (accepted as-is when it already is a boundary).
    let mut i = index;
    while !s.is_char_boundary(i) {
        i -= 1;
    }
    i
}
23use anyhow::{Context as _, Result};
24use clap::Args;
25use edit_prediction::example_spec::ExampleSpec;
26use rand::Rng;
27use rand::SeedableRng;
28use serde::{Deserialize, Serialize};
29use similar::{DiffTag, TextDiff};
30use std::collections::BTreeSet;
31use std::fs;
32use std::io::{self, Write};
33use std::path::Path;
34use std::path::PathBuf;
35
/// `ep split-commit` CLI args.
#[derive(Debug, Args, Clone)]
pub struct SplitCommitArgs {
    /// Split point (float 0.0-1.0 for fraction, or integer for index).
    /// Ignored when `num_samples` is set (each sample uses a random split).
    #[arg(long, short = 's')]
    pub split_point: Option<String>,

    /// Random seed for reproducibility. In multi-sample mode this is the base
    /// seed; sample `i` uses `seed.wrapping_add(i)`.
    #[arg(long)]
    pub seed: Option<u64>,

    /// Pretty-print JSON output
    #[arg(long, short = 'p')]
    pub pretty: bool,

    /// Number of samples to generate per commit (samples random split points)
    #[arg(long, short = 'n')]
    pub num_samples: Option<usize>,
}
55
/// Input format for annotated commits (JSON Lines).
///
/// One record per input line; deserialized with serde in `run_split_commit`.
/// Only `reordered_commit`, `repo_url`, and `commit_sha` are consumed here;
/// the remaining fields are carried for provenance (hence `dead_code`).
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)]
pub struct AnnotatedCommit {
    /// Repository path (e.g., "repos/zed")
    pub repo: String,
    /// Repository URL (e.g., "https://github.com/zed-industries/zed")
    pub repo_url: String,
    /// Commit SHA
    pub commit_sha: String,
    /// Chronologically reordered commit diff
    pub reordered_commit: String,
    /// Original commit diff
    pub original_commit: String,
    /// Whether diff stats match between original and reordered
    pub diff_stats_match: bool,
}
73
/// Cursor position in a file.
///
/// `Display` renders it as `file:line:column`.
/// NOTE(review): `column` appears to be a 1-based *byte* offset (see
/// `imitate_human_edits`, which sets it to `new_src.len() + 1`) — confirm
/// before treating it as a character column.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CursorPosition {
    pub file: String,
    pub line: usize,
    pub column: usize,
}
81
82impl std::fmt::Display for CursorPosition {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 write!(f, "{}:{}:{}", self.file, self.line, self.column)
85 }
86}
87
/// Represents a split commit with source and target patches.
///
/// `source_patch` is the prefix of the chronologically-ordered diff (edits
/// "already made" — used as the edit history), and `target_patch` is the
/// remainder (the edits a model is expected to predict).
#[derive(Debug, Clone)]
pub struct SplitCommit {
    pub source_patch: String,
    pub target_patch: String,
}
94
/// Split point specification for evaluation generation.
///
/// A `Fraction` is multiplied by the total edit count (then floored and
/// clamped); an `Index` is used directly, clamped to the edit count.
#[derive(Debug, Clone)]
pub enum SplitPoint {
    /// Fraction of total edits (0.0 to 1.0)
    Fraction(f64),
    /// Absolute index
    Index(usize),
}
103
104fn parse_split_point(value: &str) -> Option<SplitPoint> {
105 if value.contains('.') {
106 value.parse::<f64>().ok().map(SplitPoint::Fraction)
107 } else {
108 value.parse::<usize>().ok().map(SplitPoint::Index)
109 }
110}
111
/// Entry point for the `ep split-commit` subcommand.
///
/// Reads JSON Lines of [`AnnotatedCommit`] records from `inputs` (an empty
/// list, or the path "-", means stdin), generates one evaluation example per
/// record — or `num_samples` examples with random split points — and writes
/// the results as JSON Lines to `output_path` (stdout when `None`).
///
/// This runs synchronously and outputs JSON Lines (one output per input line).
pub fn run_split_commit(
    args: &SplitCommitArgs,
    inputs: &[PathBuf],
    output_path: Option<&PathBuf>,
) -> Result<()> {
    use std::collections::HashSet;
    use std::io::BufRead;

    // Default to stdin when no inputs were given.
    let stdin_path = PathBuf::from("-");
    let inputs = if inputs.is_empty() {
        std::slice::from_ref(&stdin_path)
    } else {
        inputs
    };

    // Unparseable --split-point strings silently fall back to random splits.
    let split_point = args.split_point.as_deref().and_then(parse_split_point);
    let mut output_lines = Vec::new();

    for input_path in inputs {
        let input: Box<dyn BufRead> = if input_path.as_os_str() == "-" {
            Box::new(io::BufReader::new(io::stdin()))
        } else {
            let file = fs::File::open(input_path)
                .with_context(|| format!("failed to open input file {}", input_path.display()))?;
            Box::new(io::BufReader::new(file))
        };

        // `line_num` restarts at 0 for each input file; it is only used for
        // error messages.
        for (line_num, line_result) in input.lines().enumerate() {
            let line =
                line_result.with_context(|| format!("failed to read line {}", line_num + 1))?;

            if line.trim().is_empty() {
                continue;
            }

            let annotated: AnnotatedCommit = serde_json::from_str(&line)
                .with_context(|| format!("failed to parse JSON at line {}", line_num + 1))?;

            // Generate multiple samples if num_samples is set
            if let Some(num_samples) = args.num_samples {
                // Dedupe is per-commit: identical samples across different
                // commits are all kept.
                let mut seen_samples: HashSet<String> = HashSet::new();
                let base_seed = args.seed.unwrap_or_else(|| rand::random());

                for sample_idx in 0..num_samples {
                    // Derive a distinct, reproducible seed per sample.
                    let sample_seed = base_seed.wrapping_add(sample_idx as u64);

                    let case = generate_evaluation_example_from_ordered_commit(
                        &annotated.reordered_commit,
                        &annotated.repo_url,
                        &annotated.commit_sha,
                        None, // Use random split point for multi-sample mode
                        Some(sample_seed),
                        Some(sample_idx),
                    )
                    .with_context(|| {
                        format!(
                            "failed to generate evaluation example for commit {} at line {} (sample {})",
                            annotated.commit_sha,
                            line_num + 1,
                            sample_idx
                        )
                    })?;

                    let json = if args.pretty {
                        serde_json::to_string_pretty(&case)
                    } else {
                        serde_json::to_string(&case)
                    }
                    .context("failed to serialize evaluation case as JSON")?;

                    // Only add unique samples (different split points may produce same result)
                    if seen_samples.insert(json.clone()) {
                        output_lines.push(json);
                    }
                }
            } else {
                // Single-sample mode: honor --split-point / --seed directly.
                let case = generate_evaluation_example_from_ordered_commit(
                    &annotated.reordered_commit,
                    &annotated.repo_url,
                    &annotated.commit_sha,
                    split_point.clone(),
                    args.seed,
                    None,
                )
                .with_context(|| {
                    format!(
                        "failed to generate evaluation example for commit {} at line {}",
                        annotated.commit_sha,
                        line_num + 1
                    )
                })?;

                let json = if args.pretty {
                    serde_json::to_string_pretty(&case)
                } else {
                    serde_json::to_string(&case)
                }
                .context("failed to serialize evaluation case as JSON")?;

                output_lines.push(json);
            }
        }
    }

    // Join as JSON Lines; emit a trailing newline only when there is output.
    let output_content = output_lines.join("\n") + if output_lines.is_empty() { "" } else { "\n" };

    if let Some(path) = output_path {
        fs::write(path, &output_content)
            .with_context(|| format!("failed to write output to {}", path.display()))?;
    } else {
        io::stdout()
            .write_all(output_content.as_bytes())
            .context("failed to write to stdout")?;
    }

    Ok(())
}
232
/// Main function to generate an evaluation example from an ordered commit.
///
/// Pipeline: parse/normalize the diff → pick a split point → split into
/// source (history) and target (expected) patches → simulate a partially
/// typed line (`imitate_human_edits`) → derive a cursor position and excerpt
/// → assemble the [`ExampleSpec`].
///
/// # Arguments
/// * `commit` - Chronologically ordered unified diff of the commit
/// * `repository_url` - URL of the repository
/// * `commit_hash` - Hash of the commit
/// * `split_point` - Point at which the commit will be split (None for random)
/// * `seed` - Optional seed for randomness
/// * `sample_num` - Optional sample number for generating unique names
pub fn generate_evaluation_example_from_ordered_commit(
    commit: &str,
    repository_url: &str,
    commit_hash: &str,
    split_point: Option<SplitPoint>,
    seed: Option<u64>,
    sample_num: Option<usize>,
) -> Result<ExampleSpec> {
    // Seeded runs are reproducible; otherwise fall back to thread-local RNG.
    let mut rng: Box<dyn rand::RngCore> = match seed {
        Some(seed) => Box::new(rand::rngs::StdRng::seed_from_u64(seed)),
        None => Box::new(rand::rngs::ThreadRng::default()),
    };

    // Parse and normalize the commit
    let mut patch = Patch::parse_unified_diff(commit);

    // Filter header to only keep lines starting with "//"
    // (the "//"-prefixed lines are group headers in the reordered format).
    let header_lines: Vec<&str> = patch
        .header
        .lines()
        .filter(|line| line.starts_with("//"))
        .collect();
    patch.header = if header_lines.is_empty() {
        String::new()
    } else {
        header_lines.join("\n") + "\n"
    };
    let commit_normalized = patch.to_string();

    // Compute the split point
    let stats = patch.stats();
    let num_edits = stats.added + stats.removed;

    anyhow::ensure!(num_edits != 0, "no edits found in commit");

    let split = match split_point {
        // Random splits never choose 0: at least one edit goes to the source.
        None => rng.random_range(1..=num_edits),
        Some(SplitPoint::Fraction(f)) => {
            let v = (f * num_edits as f64).floor() as usize;
            v.min(num_edits)
        }
        Some(SplitPoint::Index(i)) => i.min(num_edits),
    };

    // Split the commit into source and target patches
    let (prefix, suffix) = split_ordered_commit(&commit_normalized, split);

    let mut split_commit = SplitCommit {
        source_patch: prefix,
        target_patch: suffix,
    };

    // Imitate human edits
    let human_edit_seed = rng.random_range(1..=10000u64);
    let (src_patch, tgt_patch, cursor_opt) = imitate_human_edits(
        &split_commit.source_patch,
        &split_commit.target_patch,
        human_edit_seed,
    );
    split_commit.source_patch = src_patch;
    split_commit.target_patch = tgt_patch;

    // Sample cursor position (imitate_human_edits may already supply one).
    let cursor = match cursor_opt {
        Some(c) => c,
        None => sample_cursor_position(&patch, &split_commit)
            .context("failed to sample cursor position")?,
    };

    // Get cursor excerpt
    let cursor_excerpt = get_cursor_excerpt(
        &cursor,
        &split_commit.source_patch,
        &split_commit.target_patch,
    )
    .context("failed to generate cursor excerpt")?;

    // Handle edge case where split_point == 0 (only reachable via an explicit
    // Index(0) argument; random splits start at 1).
    // NOTE(review): this clears the target AFTER the cursor excerpt was
    // computed from it, and the cleared value becomes `expected_patches` —
    // confirm this matches the original Python behavior.
    if split == 0 {
        split_commit.target_patch = String::new();
    }

    // Example name: "<repo>-<short_sha>[-<sample>]".
    let repo_name = repository_url
        .trim_end_matches('/')
        .rsplit('/')
        .next()
        .unwrap_or("unknown");
    let short_sha = &commit_hash[..commit_hash.len().min(8)];
    let name = match sample_num {
        Some(n) => format!("{}-{}-{}", repo_name, short_sha, n),
        None => format!("{}-{}", repo_name, short_sha),
    };

    Ok(ExampleSpec {
        name,
        repository_url: repository_url.to_string(),
        // Evaluation starts from the parent of the commit being predicted.
        revision: format!("{}~1", commit_hash),
        edit_history: split_commit.source_patch.clone(),
        // cursor_position: cursor.to_string(),
        cursor_path: Path::new(&cursor.file).into(),
        // NOTE(review): the `cursor_position` field carries the excerpt text
        // (with the marker), not the `file:line:col` string — see the
        // commented-out line above.
        cursor_position: cursor_excerpt,
        expected_patches: vec![split_commit.target_patch],
        tags: vec![],
        reasoning: None,
        uncommitted_diff: String::new(),
    })
}
349
350/// Split an ordered commit into source and target commits.
351///
352/// # Arguments
353/// * `commit` - Ordered commit string
354/// * `split_pos` - Position to split the commit (number of edited lines)
355///
356/// # Returns
357/// A tuple of (source_diff, target_diff)
358pub fn split_ordered_commit(commit: &str, split_pos: usize) -> (String, String) {
359 let patch = Patch::parse_unified_diff(commit);
360 let source_edits: BTreeSet<usize> = (0..split_pos).collect();
361 let (source, target) = extract_edits(&patch, &source_edits);
362
363 let mut source_str = source.to_string();
364 let target_str = target.to_string();
365
366 // Strip last group header from the source (lines starting with "//" at the end)
367 let source_lines: Vec<&str> = source_str.lines().collect();
368 let mut end_idx = source_lines.len();
369 for i in (0..source_lines.len()).rev() {
370 if source_lines[i].starts_with("//") {
371 end_idx = i;
372 } else {
373 break;
374 }
375 }
376 if end_idx < source_lines.len() {
377 source_str = source_lines[..end_idx].join("\n");
378 if !source_str.is_empty() {
379 source_str.push('\n');
380 }
381 }
382
383 (source_str, target_str)
384}
385
/// Tokenize text into word tokens and single-character tokens.
///
/// Alphanumeric runs accumulate into one token. An underscore is appended to
/// the current run and immediately terminates it, so `"foo_bar"` becomes
/// `["foo_", "bar"]`. Every other character (punctuation, whitespace) is
/// emitted as its own one-character token.
fn tokenize(text: &str) -> Vec<String> {
    let mut tokens: Vec<String> = Vec::new();
    let mut word = String::new();

    for ch in text.chars() {
        match ch {
            c if c.is_alphanumeric() => word.push(c),
            '_' => {
                // Underscore attaches to the word and ends it.
                word.push('_');
                tokens.push(std::mem::take(&mut word));
            }
            other => {
                // Flush any pending word, then emit the character itself.
                if !word.is_empty() {
                    tokens.push(std::mem::take(&mut word));
                }
                tokens.push(other.to_string());
            }
        }
    }

    // Flush a trailing word.
    if !word.is_empty() {
        tokens.push(word);
    }

    tokens
}
416
/// Calculate the weight for a split position based on the character at that position.
///
/// `pos` is a *byte* offset into `text` (callers pass `String::len()` values
/// while building the string char-by-char).
///
/// Higher weights indicate more natural pause points (e.g., after punctuation,
/// at identifier boundaries). Lower weights indicate less natural points
/// (e.g., mid-identifier).
///
/// Fix: the previous version collected `text` into a `Vec<char>` and indexed
/// it with the byte offset `pos`, which consulted the wrong character (or fell
/// back to weight 1) whenever `text` contained multi-byte UTF-8, and allocated
/// O(n) per call. We now slice at byte boundaries directly; ASCII behavior is
/// unchanged.
fn position_weight(text: &str, pos: usize) -> u32 {
    // Out-of-range or mid-character byte positions get the default low weight.
    if pos == 0 || pos > text.len() || !text.is_char_boundary(pos) {
        return 1;
    }

    // The character just before this position (what we just "typed").
    let prev_char = match text[..pos].chars().next_back() {
        Some(c) => c,
        None => return 1,
    };

    // High weight: natural pause points (end of statement/argument, opening brackets)
    if matches!(prev_char, ',' | ';' | ':' | '(' | '[' | '{') {
        return 10;
    }

    // High weight: closing brackets (finished a group)
    if matches!(prev_char, ')' | ']' | '}') {
        return 8;
    }

    // Medium weight: operators and method chains
    if matches!(
        prev_char,
        '.' | '+' | '-' | '*' | '/' | '=' | '<' | '>' | '&' | '|' | '!'
    ) {
        return 5;
    }

    // Check if we're at the end of an identifier (word char followed by non-word char)
    let is_prev_word_char = prev_char.is_alphanumeric() || prev_char == '_';
    let is_next_word_char = text[pos..]
        .chars()
        .next()
        .map_or(false, |c| c.is_alphanumeric() || c == '_');

    if is_prev_word_char && !is_next_word_char {
        // End of identifier - high weight
        return 8;
    }

    // Whitespace is a natural pause
    if prev_char.is_whitespace() {
        return 6;
    }

    // Mid-identifier: low weight (rare autocomplete scenarios)
    if is_prev_word_char && is_next_word_char {
        return 1;
    }

    // Default medium-low weight
    3
}
476
/// Select an index from a list of weights, deterministically driven by `seed`.
///
/// The seed is reduced modulo the total weight and the first index whose
/// cumulative weight exceeds that value is returned, so larger weights claim
/// proportionally more seed values. All-zero weights fall back to a uniform
/// `seed % len` choice; an empty slice yields 0.
fn weighted_select(weights: &[u32], seed: u64) -> usize {
    if weights.is_empty() {
        return 0;
    }

    let total: u64 = weights.iter().map(|&w| u64::from(w)).sum();
    if total == 0 {
        // Fallback to uniform selection if all weights are zero.
        return seed as usize % weights.len();
    }

    // Subtract each weight from the reduced seed until it would go negative;
    // that bucket is the selected index.
    let mut remaining = seed % total;
    for (idx, &weight) in weights.iter().enumerate() {
        let weight = u64::from(weight);
        if remaining < weight {
            return idx;
        }
        remaining -= weight;
    }

    // Unreachable in practice (remaining < total), kept as a safe fallback.
    weights.len() - 1
}
506
507/// Calculate similarity ratio between two strings (0-100).
508fn fuzzy_ratio(s1: &str, s2: &str) -> u32 {
509 if s1.is_empty() && s2.is_empty() {
510 return 100;
511 }
512 if s1.is_empty() || s2.is_empty() {
513 return 0;
514 }
515
516 let diff = TextDiff::from_chars(s1, s2);
517 let matching: usize = diff
518 .ops()
519 .iter()
520 .filter_map(|op| {
521 if matches!(op.tag(), DiffTag::Equal) {
522 Some(op.new_range().len())
523 } else {
524 None
525 }
526 })
527 .sum();
528
529 let total = s1.len() + s2.len();
530 ((2 * matching * 100) / total) as u32
531}
532
533/// Imitate human edits by introducing partial line edits.
534///
535/// This function simulates how a human might incrementally type code,
536/// rather than making complete line replacements.
537pub fn imitate_human_edits(
538 source_patch: &str,
539 target_patch: &str,
540 seed: u64,
541) -> (String, String, Option<CursorPosition>) {
542 let no_change = (source_patch.to_string(), target_patch.to_string(), None);
543
544 let src_patch = Patch::parse_unified_diff(source_patch);
545 let tgt_patch = Patch::parse_unified_diff(target_patch);
546
547 if tgt_patch.hunks.is_empty() {
548 return no_change;
549 }
550
551 // Try to locate the first edit in target
552 let tgt_edit_loc = match locate_edited_line(&tgt_patch, 0) {
553 Some(loc) => loc,
554 None => return no_change,
555 };
556
557 let tgt_is_addition = matches!(tgt_edit_loc.patch_line, PatchLine::Addition(_));
558 if !tgt_is_addition {
559 return no_change;
560 }
561
562 let tgt_line = match &tgt_edit_loc.patch_line {
563 PatchLine::Addition(s) => s.clone(),
564 _ => return no_change,
565 };
566
567 // Try to locate the last edit in source
568 let src_edit_loc = locate_edited_line(&src_patch, -1);
569
570 // Check if source has ANY edit at the same line as target's first edit
571 // We need to iterate through all edits to check this
572 let src_has_edit_at_target_line = {
573 let mut found = false;
574 let mut idx = 0isize;
575 while let Some(loc) = locate_edited_line(&src_patch, idx) {
576 if loc.filename == tgt_edit_loc.filename
577 && loc.target_line_number == tgt_edit_loc.target_line_number
578 {
579 found = true;
580 break;
581 }
582 idx += 1;
583 }
584 found
585 };
586
587 // Check if this is a replacement (deletion followed by insertion on the same line)
588 // or a pure insertion (no corresponding deletion in source)
589 let is_replacement = src_edit_loc.as_ref().map_or(false, |loc| {
590 matches!(loc.patch_line, PatchLine::Deletion(_))
591 && loc.filename == tgt_edit_loc.filename
592 && loc.target_line_number == tgt_edit_loc.target_line_number
593 });
594
595 // If source has an edit at the same line but it's not a replacement (i.e., it's an addition),
596 // we shouldn't process this as a pure insertion either
597 if !is_replacement && src_has_edit_at_target_line {
598 return no_change;
599 }
600
601 let src_line = if is_replacement {
602 match &src_edit_loc.as_ref().unwrap().patch_line {
603 PatchLine::Deletion(s) => s.clone(),
604 _ => return no_change,
605 }
606 } else {
607 // Pure insertion: source line is empty
608 String::new()
609 };
610
611 // Don't process if source and target are the same
612 if src_line == tgt_line {
613 return no_change;
614 }
615
616 // Tokenize both lines
617 let src_tokens = tokenize(&src_line);
618 let tgt_tokens = tokenize(&tgt_line);
619
620 // Convert to slices for similar
621 let src_refs: Vec<&str> = src_tokens.iter().map(|s| s.as_str()).collect();
622 let tgt_refs: Vec<&str> = tgt_tokens.iter().map(|s| s.as_str()).collect();
623
624 // Use similar to get diff operations
625 let diff = TextDiff::from_slices(&src_refs, &tgt_refs);
626
627 // Build weights for each possible split position
628 let mut position_weights: Vec<u32> = Vec::new();
629
630 // Simulate the edit process to collect weights for all possible split positions
631 {
632 let mut current_text = String::new();
633
634 for op in diff.ops() {
635 match op.tag() {
636 DiffTag::Equal => {
637 for i in op.old_range() {
638 current_text.push_str(&src_tokens[i]);
639 }
640 }
641 DiffTag::Replace => {
642 let ins: String = op.new_range().map(|i| tgt_tokens[i].as_str()).collect();
643 let del: String = op.old_range().map(|i| src_tokens[i].as_str()).collect();
644
645 // For insertion part
646 for ch in ins.chars() {
647 current_text.push(ch);
648 let weight = position_weight(¤t_text, current_text.len());
649 position_weights.push(weight);
650 }
651
652 // For deletion part (we're "untyping" from source)
653 for _ in del.chars() {
654 // Weight deletions lower as they represent removing text
655 position_weights.push(2);
656 }
657 }
658 DiffTag::Insert => {
659 let ins: String = op.new_range().map(|i| tgt_tokens[i].as_str()).collect();
660 for ch in ins.chars() {
661 current_text.push(ch);
662 let weight = position_weight(¤t_text, current_text.len());
663 position_weights.push(weight);
664 }
665 }
666 DiffTag::Delete => {
667 let del: String = op.old_range().map(|i| src_tokens[i].as_str()).collect();
668 for _ in del.chars() {
669 // Weight deletions lower
670 position_weights.push(2);
671 }
672 }
673 }
674 }
675 }
676
677 // Use weighted selection to choose split index
678 if position_weights.is_empty() {
679 return no_change;
680 }
681 let split_index = weighted_select(&position_weights, seed);
682
683 let mut edit_index = 0usize;
684 let mut new_src = String::new();
685 let mut split_found = false;
686 let mut last_old_end = 0usize;
687
688 for op in diff.ops() {
689 match op.tag() {
690 DiffTag::Equal => {
691 for i in op.old_range() {
692 new_src.push_str(&src_tokens[i]);
693 }
694 last_old_end = op.old_range().end;
695 }
696 DiffTag::Replace => {
697 // Handle replace as delete + insert
698 let del: String = op.old_range().map(|i| src_tokens[i].as_str()).collect();
699 let ins: String = op.new_range().map(|i| tgt_tokens[i].as_str()).collect();
700 let repl_len = del.len() + ins.len();
701 if edit_index + repl_len >= split_index {
702 // Split within this replace operation
703 let offset = split_index - edit_index;
704 if offset < ins.len() {
705 let safe_offset = floor_char_boundary(&ins, offset);
706 new_src.push_str(&ins[..safe_offset]);
707 } else {
708 new_src.push_str(&ins);
709 let del_offset = offset - ins.len();
710 let safe_del_offset = floor_char_boundary(&del, del_offset.min(del.len()));
711 new_src.push_str(&del[..safe_del_offset]);
712 }
713 split_found = true;
714 last_old_end = op.old_range().end;
715 break;
716 } else {
717 edit_index += repl_len;
718 new_src.push_str(&ins);
719 last_old_end = op.old_range().end;
720 }
721 }
722 DiffTag::Insert => {
723 let repl: String = op.new_range().map(|i| tgt_tokens[i].as_str()).collect();
724 if edit_index + repl.len() >= split_index {
725 let offset = split_index - edit_index;
726 let safe_offset = floor_char_boundary(&repl, offset);
727 new_src.push_str(&repl[..safe_offset]);
728 split_found = true;
729 break;
730 } else {
731 edit_index += repl.len();
732 new_src.push_str(&repl);
733 }
734 }
735 DiffTag::Delete => {
736 let repl: String = op.old_range().map(|i| src_tokens[i].as_str()).collect();
737 if edit_index + repl.len() >= split_index {
738 let offset = split_index - edit_index;
739 let safe_offset = floor_char_boundary(&repl, offset);
740 new_src.push_str(&repl[..safe_offset]);
741 split_found = true;
742 last_old_end = op.old_range().start + safe_offset.min(op.old_range().len());
743 break;
744 } else {
745 edit_index += repl.len();
746 new_src.push_str(&repl);
747 last_old_end = op.old_range().end;
748 }
749 }
750 }
751 }
752
753 if !split_found {
754 return no_change;
755 }
756
757 // Calculate cursor position
758 let cursor = CursorPosition {
759 file: tgt_edit_loc.filename.clone(),
760 line: if is_replacement {
761 src_edit_loc.as_ref().unwrap().source_line_number
762 } else {
763 tgt_edit_loc.target_line_number
764 },
765 column: new_src.len() + 1,
766 };
767
768 // Add remainder of source if similar enough to target remainder
769 let remainder_src: String = (last_old_end..src_tokens.len())
770 .map(|i| src_tokens[i].as_str())
771 .collect();
772 let remainder_tgt: String = (last_old_end..tgt_tokens.len())
773 .filter_map(|i| tgt_tokens.get(i).map(|s| s.as_str()))
774 .collect();
775
776 let ratio = fuzzy_ratio(&remainder_src, &remainder_tgt);
777 if ratio > 35 {
778 new_src.push_str(&remainder_src);
779 }
780
781 if new_src.trim().is_empty() {
782 return no_change;
783 }
784
785 if new_src == src_line {
786 return no_change;
787 }
788
789 // Build new source patch with the intermediate line
790 let mut new_src_patch = src_patch;
791 if is_replacement {
792 // For replacements, insert after the deletion line
793 let src_loc = src_edit_loc.as_ref().unwrap();
794 if let Some(hunk) = new_src_patch.hunks.get_mut(src_loc.hunk_index) {
795 hunk.lines.insert(
796 src_loc.line_index_within_hunk + 1,
797 PatchLine::Addition(new_src.clone()),
798 );
799 hunk.new_count += 1;
800 }
801 } else {
802 // For pure insertions, we need to add or modify a hunk
803 // Check if the source hunk exists AND has enough lines for the target's line index
804 let can_insert_in_existing_hunk = new_src_patch
805 .hunks
806 .get(tgt_edit_loc.hunk_index)
807 .map_or(false, |hunk| {
808 tgt_edit_loc.line_index_within_hunk <= hunk.lines.len()
809 });
810
811 if can_insert_in_existing_hunk {
812 if let Some(hunk) = new_src_patch.hunks.get_mut(tgt_edit_loc.hunk_index) {
813 // Insert the partial line at the same position as target
814 hunk.lines.insert(
815 tgt_edit_loc.line_index_within_hunk,
816 PatchLine::Addition(new_src.clone()),
817 );
818 hunk.new_count += 1;
819 }
820 } else {
821 // Source patch is empty or has incompatible hunk structure, create a new hunk based on target
822 if let Some(tgt_hunk) = tgt_patch.hunks.get(tgt_edit_loc.hunk_index) {
823 let mut new_hunk = tgt_hunk.clone();
824 // Replace the full addition with the partial one
825 new_hunk.lines.clear();
826 for (i, line) in tgt_hunk.lines.iter().enumerate() {
827 if i == tgt_edit_loc.line_index_within_hunk {
828 new_hunk.lines.push(PatchLine::Addition(new_src.clone()));
829 } else {
830 match line {
831 PatchLine::Addition(_) => {
832 // Skip other additions from target
833 }
834 _ => new_hunk.lines.push(line.clone()),
835 }
836 }
837 }
838 new_hunk.new_count = new_hunk.old_count + 1;
839 new_src_patch.hunks.push(new_hunk);
840 // Copy header from target if source doesn't have one
841 if new_src_patch.header.is_empty() {
842 new_src_patch.header = tgt_patch.header.clone();
843 }
844 }
845 }
846 }
847
848 // Build new target patch with the intermediate line as deletion
849 let mut new_tgt_patch = tgt_patch;
850 if let Some(hunk) = new_tgt_patch.hunks.get_mut(tgt_edit_loc.hunk_index) {
851 hunk.lines.insert(
852 tgt_edit_loc.line_index_within_hunk,
853 PatchLine::Deletion(new_src),
854 );
855 hunk.old_count += 1;
856 }
857
858 (
859 new_src_patch.to_string(),
860 new_tgt_patch.to_string(),
861 Some(cursor),
862 )
863}
864
865/// Locate the end of the last edit in a patch.
866fn locate_end_of_last_edit(patch: &Patch) -> Option<CursorPosition> {
867 let loc = locate_edited_line(patch, -1)?;
868
869 let (line, col) = match &loc.patch_line {
870 PatchLine::Addition(content) => (loc.target_line_number, content.len()),
871 PatchLine::Deletion(_) => (loc.target_line_number, 1),
872 _ => return None,
873 };
874
875 Some(CursorPosition {
876 file: loc.filename,
877 line,
878 column: col,
879 })
880}
881
/// Locate the beginning of the first edit in a patch.
///
/// The cursor is placed on the line *before* the first edited line (clamped
/// to line 1), at the end of that previous line's content when one exists
/// within the hunk, otherwise at column 0. Returns `None` when the patch has
/// no edits or the previous hunk line has an unexpected variant.
fn locate_beginning_of_first_edit(patch: &Patch) -> Option<CursorPosition> {
    let loc = locate_edited_line(patch, 0)?;

    let hunk = patch.hunks.get(loc.hunk_index)?;
    let column = if loc.line_index_within_hunk > 0 {
        if let Some(prev_line) = hunk.lines.get(loc.line_index_within_hunk - 1) {
            let content = match prev_line {
                PatchLine::Context(s) | PatchLine::Addition(s) | PatchLine::Deletion(s) => s,
                _ => return None,
            };
            // `.max(1) - 1` avoids underflow for empty content (yields 0).
            content.len().max(1) - 1
        } else {
            0
        }
    } else {
        0
    };

    // One line above the first edit, never below line 1.
    let line = loc.target_line_number.saturating_sub(1).max(1);

    Some(CursorPosition {
        file: loc.filename,
        line,
        column,
    })
}
909
/// Choose a cursor position for the example.
///
/// NOTE(review): despite the original "50% chance" description, this is fully
/// deterministic — it prefers, in order:
/// 1. the end of the last edit in the source patch,
/// 2. the beginning of the first edit in the target patch,
/// 3. the end of the last edit in the full original `patch` (fallback).
pub fn sample_cursor_position(patch: &Patch, split_commit: &SplitCommit) -> Option<CursorPosition> {
    // Try end of history first
    let src_patch = Patch::parse_unified_diff(&split_commit.source_patch);
    if let Some(cursor) = locate_end_of_last_edit(&src_patch) {
        return Some(cursor);
    }

    // Try beginning of target
    let tgt_patch = Patch::parse_unified_diff(&split_commit.target_patch);
    if let Some(cursor) = locate_beginning_of_first_edit(&tgt_patch) {
        return Some(cursor);
    }

    // Fallback: use the original patch
    locate_end_of_last_edit(patch)
}
929
/// Get cursor excerpt from the patches.
///
/// This extracts the lines around the cursor position with a `<|user_cursor|>`
/// marker inserted at the cursor column. The excerpt is taken, in order of
/// preference, from: the source patch's last-edit hunk (post-edit view), any
/// target-patch hunk for the cursor file (pre-edit view), or any source-patch
/// hunk for the cursor file. Returns `None` when no matching lines are found.
///
/// NOTE(review): line numbering for the marker assumes the collected excerpt
/// lines are contiguous starting at the hunk start (`excerpt_first_line + i`)
/// — confirm this holds when hunks mix additions/deletions/context.
pub fn get_cursor_excerpt(
    cursor: &CursorPosition,
    source_patch: &str,
    target_patch: &str,
) -> Option<String> {
    let mut excerpt_lines: Vec<String> = Vec::new();
    let mut excerpt_first_line: usize = 0;

    // Search in the last hunk of source patch
    let src = Patch::parse_unified_diff(source_patch);
    if let Some(loc) = locate_edited_line(&src, -1) {
        if loc.filename == cursor.file && loc.target_line_number == cursor.line {
            if let Some(hunk) = src.hunks.get(loc.hunk_index) {
                // Post-edit view: additions + context, numbered from new_start.
                excerpt_first_line = hunk.new_start as usize;
                for line in &hunk.lines {
                    match line {
                        PatchLine::Addition(s) | PatchLine::Context(s) => {
                            excerpt_lines.push(s.clone());
                        }
                        _ => {}
                    }
                }
                // If hunk only has deletions (file deletion), include deletion lines
                if excerpt_lines.is_empty() {
                    excerpt_first_line = hunk.old_start as usize;
                    for line in &hunk.lines {
                        match line {
                            PatchLine::Deletion(s) => {
                                excerpt_lines.push(s.clone());
                            }
                            _ => {}
                        }
                    }
                }
            }
        }
    }

    // Search in target patch if not found
    if excerpt_lines.is_empty() {
        let tgt = Patch::parse_unified_diff(target_patch);
        // Search all hunks for the cursor file, not just the first edit's hunk
        for hunk in &tgt.hunks {
            if hunk.filename == cursor.file {
                excerpt_first_line = hunk.new_start as usize;
                // First try to collect deletions and context (what exists before edits)
                for line in &hunk.lines {
                    match line {
                        PatchLine::Deletion(s) | PatchLine::Context(s) => {
                            excerpt_lines.push(s.clone());
                        }
                        _ => {}
                    }
                }
                // If hunk only has additions (no deletions/context), include all lines
                // This handles cases like adding to an empty file or section
                if excerpt_lines.is_empty() {
                    for line in &hunk.lines {
                        match line {
                            PatchLine::Addition(s)
                            | PatchLine::Deletion(s)
                            | PatchLine::Context(s) => {
                                excerpt_lines.push(s.clone());
                            }
                            _ => {}
                        }
                    }
                }
                if !excerpt_lines.is_empty() {
                    break;
                }
            }
        }
    }

    // Also search source patch hunks if still not found (for fallback cursor case)
    if excerpt_lines.is_empty() {
        for hunk in &src.hunks {
            if hunk.filename == cursor.file {
                excerpt_first_line = hunk.new_start as usize;
                for line in &hunk.lines {
                    match line {
                        PatchLine::Addition(s) | PatchLine::Context(s) => {
                            excerpt_lines.push(s.clone());
                        }
                        _ => {}
                    }
                }
                // If hunk only has deletions, include deletion lines
                if excerpt_lines.is_empty() {
                    excerpt_first_line = hunk.old_start as usize;
                    for line in &hunk.lines {
                        match line {
                            PatchLine::Deletion(s) => {
                                excerpt_lines.push(s.clone());
                            }
                            _ => {}
                        }
                    }
                }
                if !excerpt_lines.is_empty() {
                    break;
                }
            }
        }
    }

    if excerpt_lines.is_empty() {
        return None;
    }

    // Add cursor marker on the line matching the cursor's line number.
    for (i, line) in excerpt_lines.iter_mut().enumerate() {
        let line_num = excerpt_first_line + i;
        if line_num == cursor.line {
            // Column is byte-based; clamp to the line length.
            let col = cursor.column.min(line.len());
            // Ensure we split at a valid UTF-8 character boundary
            let col = if line.is_char_boundary(col) {
                col
            } else {
                // Find the nearest valid character boundary
                (0..=col)
                    .rev()
                    .find(|&i| line.is_char_boundary(i))
                    .unwrap_or(0)
            };
            let (before, after) = line.split_at(col);
            *line = format!("{}<|user_cursor|>{}", before, after);
            break;
        }
    }

    Some(excerpt_lines.join("\n"))
}
1067
1068#[cfg(test)]
1069mod tests {
1070 use std::path::Path;
1071
1072 use edit_prediction::example_spec::ExampleSpec;
1073
1074 use super::*;
1075
    /// Tokenization splits on word/punctuation boundaries; underscores stay
    /// attached to the preceding word and terminate the token.
    #[test]
    fn test_tokenize() {
        let tokens = tokenize("hello world");
        assert_eq!(tokens, vec!["hello", " ", "world"]);

        let tokens = tokenize("foo_bar123 + baz");
        assert_eq!(tokens, vec!["foo_", "bar123", " ", "+", " ", "baz"]);

        let tokens = tokenize("print(\"hello\")");
        assert_eq!(tokens, vec!["print", "(", "\"", "hello", "\"", ")"]);

        let tokens = tokenize("hello_world");
        assert_eq!(tokens, vec!["hello_", "world"]);

        let tokens = tokenize("fn();");
        assert_eq!(tokens, vec!["fn", "(", ")", ";"]);
    }
1093
1094 #[test]
1095 fn test_fuzzy_ratio() {
1096 assert_eq!(fuzzy_ratio("hello", "hello"), 100);
1097 assert_eq!(fuzzy_ratio("", ""), 100);
1098 assert!(fuzzy_ratio("hello", "world") < 50);
1099 assert!(fuzzy_ratio("hello world", "hello worl") > 80);
1100 }
1101
1102 #[test]
1103 fn test_split_ordered_commit() {
1104 let commit = r#"// First change
1105--- a/test.rs
1106+++ b/test.rs
1107@@ -1,3 +1,4 @@
1108 fn main() {
1109+ println!("hello");
1110+ println!("world");
1111 }
1112"#;
1113 let patch = Patch::parse_unified_diff(commit);
1114 let stats = patch.stats();
1115 assert_eq!(stats.added, 2);
1116
1117 let (source, target) = split_ordered_commit(commit, 1);
1118
1119 // Source should have 1 addition
1120 let src_patch = Patch::parse_unified_diff(&source);
1121 assert_eq!(src_patch.stats().added, 1);
1122
1123 // Target should have 1 addition
1124 let tgt_patch = Patch::parse_unified_diff(&target);
1125 assert_eq!(tgt_patch.stats().added, 1);
1126 }
1127
1128 #[test]
1129 fn test_split_ordered_commit_with_deletions() {
1130 let commit = r#"// Change
1131--- a/test.rs
1132+++ b/test.rs
1133@@ -1,3 +1,3 @@
1134 fn main() {
1135- println!("old");
1136+ println!("new");
1137 }
1138"#;
1139 let patch = Patch::parse_unified_diff(commit);
1140 let stats = patch.stats();
1141 assert_eq!(stats.added, 1);
1142 assert_eq!(stats.removed, 1);
1143
1144 // Split at position 1 (after the deletion)
1145 let (source, target) = split_ordered_commit(commit, 1);
1146
1147 let src_patch = Patch::parse_unified_diff(&source);
1148 let tgt_patch = Patch::parse_unified_diff(&target);
1149
1150 // Source should have the deletion
1151 assert_eq!(src_patch.stats().removed, 1);
1152 // Target should have the addition
1153 assert_eq!(tgt_patch.stats().added, 1);
1154 }
1155
1156 #[test]
1157 fn test_generate_evaluation_example() {
1158 let commit = r#"commit abc123
1159Author: Test <test@example.com>
1160Date: Mon Jan 1 00:00:00 2024
1161
1162 Test commit
1163
1164////////////////////////////////////////////////////////////////////////////////
1165// Add greeting
1166////////////////////////////////////////////////////////////////////////////////
1167--- a/test.rs
1168+++ b/test.rs
1169@@ -1,3 +1,5 @@
1170 fn main() {
1171+ println!("hello");
1172+ println!("world");
1173 }
1174"#;
1175
1176 let result = generate_evaluation_example_from_ordered_commit(
1177 commit,
1178 "https://github.com/test/repo",
1179 "abc123",
1180 Some(SplitPoint::Fraction(0.5)),
1181 Some(42),
1182 None,
1183 );
1184
1185 assert!(result.is_ok());
1186 let case = result.unwrap();
1187 assert_eq!(case.repository_url, "https://github.com/test/repo");
1188 assert_eq!(case.revision, "abc123~1");
1189 assert!(!case.edit_history.is_empty());
1190 }
1191
1192 #[test]
1193 fn test_generate_evaluation_example_reproducible() {
1194 let commit = r#"////////////////////////////////////////////////////////////////////////////////
1195// Add greeting
1196////////////////////////////////////////////////////////////////////////////////
1197--- a/test.rs
1198+++ b/test.rs
1199@@ -1,3 +1,5 @@
1200 fn main() {
1201+ println!("hello");
1202+ println!("world");
1203 }
1204"#;
1205
1206 // Run twice with the same seed
1207 let result1 = generate_evaluation_example_from_ordered_commit(
1208 commit,
1209 "https://github.com/test/repo",
1210 "abc123",
1211 Some(SplitPoint::Fraction(0.5)),
1212 Some(12345),
1213 None,
1214 )
1215 .unwrap();
1216
1217 let result2 = generate_evaluation_example_from_ordered_commit(
1218 commit,
1219 "https://github.com/test/repo",
1220 "abc123",
1221 Some(SplitPoint::Fraction(0.5)),
1222 Some(12345),
1223 None,
1224 )
1225 .unwrap();
1226
1227 // Results should be identical
1228 assert_eq!(result1.edit_history, result2.edit_history);
1229 assert_eq!(result1.expected_patches, result2.expected_patches);
1230 assert_eq!(result1.cursor_position, result2.cursor_position);
1231 }
1232
1233 #[test]
1234 fn test_cursor_position_display() {
1235 let cursor = CursorPosition {
1236 file: "src/main.rs".to_string(),
1237 line: 42,
1238 column: 10,
1239 };
1240 assert_eq!(cursor.to_string(), "src/main.rs:42:10");
1241 }
1242
1243 #[test]
1244 fn test_imitate_human_edits_no_change_when_no_replacement() {
1245 // Source and target patches that don't form a replacement pattern
1246 let source = r#"--- a/test.rs
1247+++ b/test.rs
1248@@ -1,3 +1,4 @@
1249 fn main() {
1250+ println!("hello");
1251 }
1252"#;
1253 let target = r#"--- a/test.rs
1254+++ b/test.rs
1255@@ -1,3 +1,4 @@
1256 fn main() {
1257+ println!("world");
1258 }
1259"#;
1260
1261 let (new_src, new_tgt, cursor) = imitate_human_edits(source, target, 42);
1262
1263 // Should return unchanged when not a replacement pattern
1264 assert_eq!(new_src, source);
1265 assert_eq!(new_tgt, target);
1266 assert!(cursor.is_none());
1267 }
1268
1269 #[test]
1270 fn test_split_point_fraction() {
1271 let commit = r#"// Change
1272--- a/test.rs
1273+++ b/test.rs
1274@@ -1,5 +1,10 @@
1275 fn main() {
1276+ line1();
1277+ line2();
1278+ line3();
1279+ line4();
1280+ line5();
1281 }
1282"#;
1283
1284 // Split at 20% should give first edit in source
1285 let result = generate_evaluation_example_from_ordered_commit(
1286 commit,
1287 "",
1288 "hash",
1289 Some(SplitPoint::Fraction(0.2)),
1290 Some(1),
1291 None,
1292 );
1293
1294 assert!(result.is_ok());
1295 let case = result.unwrap();
1296
1297 // Source should have some edits
1298 let src_patch = Patch::parse_unified_diff(&case.edit_history);
1299 assert!(src_patch.stats().added > 0);
1300 }
1301
1302 #[test]
1303 fn test_split_point_index() {
1304 let commit = r#"// Change
1305--- a/test.rs
1306+++ b/test.rs
1307@@ -1,5 +1,10 @@
1308 fn main() {
1309+ line1();
1310+ line2();
1311+ line3();
1312+ line4();
1313+ line5();
1314 }
1315"#;
1316
1317 // Split at index 2 should give first 2 edits in source
1318 // With pure insertion handling, source gets 2 original + 1 partial = 3 additions
1319 let result = generate_evaluation_example_from_ordered_commit(
1320 commit,
1321 "",
1322 "hash",
1323 Some(SplitPoint::Index(2)),
1324 Some(1),
1325 None,
1326 );
1327
1328 assert!(result.is_ok());
1329 let case = result.unwrap();
1330
1331 let src_patch = Patch::parse_unified_diff(&case.edit_history);
1332 // Pure insertion adds a partial line, so we expect 3 (2 original + 1 partial)
1333 assert_eq!(src_patch.stats().added, 3);
1334 }
1335
1336 #[test]
1337 fn test_cursor_excerpt_contains_marker() {
1338 let commit = r#"////////////////////////////////////////////////////////////////////////////////
1339// Add code
1340////////////////////////////////////////////////////////////////////////////////
1341--- a/test.rs
1342+++ b/test.rs
1343@@ -1,3 +1,5 @@
1344 fn main() {
1345+ println!("hello");
1346+ println!("world");
1347 }
1348"#;
1349
1350 let result = generate_evaluation_example_from_ordered_commit(
1351 commit,
1352 "",
1353 "hash",
1354 Some(SplitPoint::Fraction(0.5)),
1355 Some(42),
1356 None,
1357 )
1358 .unwrap();
1359
1360 // Cursor excerpt should contain the cursor marker
1361 assert!(
1362 result.cursor_position.contains("<|user_cursor|>"),
1363 "Cursor excerpt should contain marker: {}",
1364 result.cursor_position
1365 );
1366 }
1367
1368 #[test]
1369 fn test_evaluation_case_json_serialization() {
1370 let case = ExampleSpec {
1371 name: "test-abc123".to_string(),
1372 repository_url: "https://github.com/test/repo".to_string(),
1373 revision: "abc123~1".to_string(),
1374 edit_history: "patch1".to_string(),
1375 // cursor_position: "file.rs:10:5".to_string(),
1376 cursor_path: Path::new("file.rs").into(),
1377 cursor_position: "some code<|user_cursor|>".to_string(),
1378 expected_patches: vec!["patch".to_string()],
1379 tags: vec![],
1380 reasoning: None,
1381 uncommitted_diff: String::new(),
1382 };
1383
1384 let json = serde_json::to_string(&case).unwrap();
1385 let deserialized: ExampleSpec = serde_json::from_str(&json).unwrap();
1386
1387 assert_eq!(case.repository_url, deserialized.repository_url);
1388 assert_eq!(case.revision, deserialized.revision);
1389 assert_eq!(case.cursor_position, deserialized.cursor_position);
1390 }
1391
1392 #[test]
1393 fn test_empty_commit_returns_error() {
1394 let commit = "";
1395
1396 let result = generate_evaluation_example_from_ordered_commit(
1397 commit,
1398 "",
1399 "hash",
1400 Some(SplitPoint::Fraction(0.5)),
1401 Some(1),
1402 None,
1403 );
1404
1405 assert!(result.is_err());
1406 }
1407
1408 #[test]
1409 fn test_header_filtering() {
1410 let commit = r#"commit abc123
1411Author: Test
1412Date: Today
1413
1414 Message
1415
1416diff --git a/test.rs b/test.rs
1417index 123..456 789
1418////////////////////////////////////////////////////////////////////////////////
1419// First group
1420////////////////////////////////////////////////////////////////////////////////
1421--- a/test.rs
1422+++ b/test.rs
1423@@ -1,3 +1,4 @@
1424 fn main() {
1425+ code();
1426 }
1427"#;
1428
1429 let result = generate_evaluation_example_from_ordered_commit(
1430 commit,
1431 "",
1432 "hash",
1433 Some(SplitPoint::Index(1)),
1434 Some(1),
1435 None,
1436 );
1437
1438 assert!(result.is_ok());
1439 let case = result.unwrap();
1440
1441 // The edit history should contain the group header (// lines)
1442 // but not the commit metadata
1443 assert!(!case.edit_history.contains("Author:"));
1444 assert!(!case.edit_history.contains("Date:"));
1445 }
1446
1447 #[test]
1448 fn test_position_weight() {
1449 // High weight positions (natural pause points)
1450 assert_eq!(position_weight("foo(", 4), 10); // After '('
1451 assert_eq!(position_weight("a, b", 2), 10); // After ','
1452 assert_eq!(position_weight("x;", 2), 10); // After ';'
1453 assert_eq!(position_weight("a: b", 2), 10); // After ':'
1454 assert_eq!(position_weight("[", 1), 10); // After '['
1455 assert_eq!(position_weight("{", 1), 10); // After '{'
1456
1457 // High weight for closing brackets
1458 assert_eq!(position_weight("foo)", 4), 8); // After ')'
1459 assert_eq!(position_weight("]", 1), 8); // After ']'
1460 assert_eq!(position_weight("}", 1), 8); // After '}'
1461
1462 // High weight at end of identifier
1463 assert_eq!(position_weight("foo ", 3), 8); // End of 'foo' before space
1464 assert_eq!(position_weight("bar(", 3), 8); // End of 'bar' before '('
1465
1466 // Medium weight for operators
1467 assert_eq!(position_weight("a + b", 3), 5); // After '+'
1468 assert_eq!(position_weight("x.", 2), 5); // After '.'
1469 assert_eq!(position_weight("a=b", 2), 5); // After '='
1470
1471 // Medium weight for whitespace
1472 assert_eq!(position_weight("a ", 2), 6); // After space
1473
1474 // Low weight mid-identifier
1475 assert_eq!(position_weight("foobar", 3), 1); // Mid-identifier 'foo|bar'
1476
1477 // Edge cases
1478 assert_eq!(position_weight("", 0), 1); // Empty string
1479 assert_eq!(position_weight("a", 0), 1); // Position 0
1480 }
1481
1482 #[test]
1483 fn test_weighted_select() {
1484 // Test that weighted selection returns correct indices
1485 let weights = vec![1, 10, 1];
1486
1487 // With total weight 12, seed 0 should select index 0
1488 // seed 0 % 12 = 0, cumulative: 1 at idx 0, so returns 0
1489 assert_eq!(weighted_select(&weights, 0), 0);
1490
1491 // seed 1 % 12 = 1, cumulative: 1 at idx 0 (1 < 1 is false), 11 at idx 1 (1 < 11 is true)
1492 assert_eq!(weighted_select(&weights, 1), 1);
1493
1494 // seed 10 % 12 = 10, cumulative: 1, 11 at idx 1 (10 < 11 is true)
1495 assert_eq!(weighted_select(&weights, 10), 1);
1496
1497 // seed 11 % 12 = 11, cumulative: 1, 11 at idx 1 (11 < 11 is false), 12 at idx 2 (11 < 12 is true)
1498 assert_eq!(weighted_select(&weights, 11), 2);
1499
1500 // Empty weights should return 0
1501 let empty: Vec<u32> = vec![];
1502 assert_eq!(weighted_select(&empty, 42), 0);
1503
1504 // Single weight should always return index 0
1505 let single = vec![10];
1506 assert_eq!(weighted_select(&single, 0), 0);
1507 assert_eq!(weighted_select(&single, 100), 0);
1508 }
1509
1510 #[test]
1511 fn test_weighted_split_prefers_natural_boundaries() {
1512 // Test that with different seeds, weighted selection tends to prefer
1513 // positions after punctuation over mid-identifier positions
1514 let text_with_punctuation = "foo(bar, baz)";
1515 let text_mid_identifier = "foobar";
1516
1517 // Position after '(' should have high weight
1518 let weight_after_paren = position_weight(text_with_punctuation, 4);
1519 // Position after ',' should have high weight
1520 let weight_after_comma = position_weight(text_with_punctuation, 8);
1521 // Position mid-identifier should have low weight
1522 let weight_mid_ident = position_weight(text_mid_identifier, 3);
1523
1524 assert!(
1525 weight_after_paren > weight_mid_ident,
1526 "After '(' ({}) should be weighted higher than mid-identifier ({})",
1527 weight_after_paren,
1528 weight_mid_ident
1529 );
1530 assert!(
1531 weight_after_comma > weight_mid_ident,
1532 "After ',' ({}) should be weighted higher than mid-identifier ({})",
1533 weight_after_comma,
1534 weight_mid_ident
1535 );
1536 }
1537
1538 #[test]
1539 fn test_imitate_human_edits_pure_insertion() {
1540 // Source patch is empty (no edits yet)
1541 // Target patch has a pure insertion (adding a new line)
1542 let source = r#"--- a/test.rs
1543+++ b/test.rs
1544@@ -1,2 +1,2 @@
1545 fn main() {
1546 }
1547"#;
1548 let target = r#"--- a/test.rs
1549+++ b/test.rs
1550@@ -1,2 +1,3 @@
1551 fn main() {
1552+ println!("debug");
1553 }
1554"#;
1555
1556 let (new_src, new_tgt, cursor) = imitate_human_edits(source, target, 42);
1557
1558 // Should have transformed the patches
1559 assert_ne!(
1560 new_src, source,
1561 "Source should be modified for pure insertion"
1562 );
1563 assert_ne!(
1564 new_tgt, target,
1565 "Target should be modified for pure insertion"
1566 );
1567 assert!(cursor.is_some(), "Cursor should be set");
1568
1569 // Source should now have a partial addition
1570 let src_patch = Patch::parse_unified_diff(&new_src);
1571 assert!(
1572 src_patch.stats().added > 0,
1573 "Source should have added lines"
1574 );
1575
1576 // Target should have both a deletion (of partial) and addition (of full)
1577 let tgt_patch = Patch::parse_unified_diff(&new_tgt);
1578 assert!(
1579 tgt_patch.stats().removed > 0,
1580 "Target should have removed lines (partial)"
1581 );
1582 assert!(
1583 tgt_patch.stats().added > 0,
1584 "Target should have added lines (full)"
1585 );
1586
1587 // The cursor should be in test.rs
1588 let cursor = cursor.unwrap();
1589 assert_eq!(cursor.file, "test.rs");
1590 }
1591
1592 #[test]
1593 fn test_imitate_human_edits_pure_insertion_empty_source() {
1594 // Source patch has no hunks at all
1595 let source = "";
1596 let target = r#"--- a/test.rs
1597+++ b/test.rs
1598@@ -1,2 +1,3 @@
1599 fn main() {
1600+ println!("hello");
1601 }
1602"#;
1603
1604 let (new_src, _new_tgt, cursor) = imitate_human_edits(source, target, 123);
1605
1606 // Should have created a source patch with partial insertion
1607 assert!(!new_src.is_empty(), "Source should not be empty");
1608 assert!(cursor.is_some(), "Cursor should be set");
1609
1610 let src_patch = Patch::parse_unified_diff(&new_src);
1611 assert!(
1612 src_patch.stats().added > 0,
1613 "Source should have added lines"
1614 );
1615 }
1616
1617 #[test]
1618 fn test_imitate_human_edits_pure_insertion_intermediate_content() {
1619 // Verify the actual intermediate content is a realistic partial typing state
1620 let source = "";
1621 let target = r#"--- a/test.rs
1622+++ b/test.rs
1623@@ -1,2 +1,3 @@
1624 fn main() {
1625+ println!("hello world");
1626 }
1627"#;
1628
1629 // Test with multiple seeds to see different split points
1630 let mut found_partial = false;
1631 for seed in 1..=50 {
1632 let (new_src, new_tgt, cursor) = imitate_human_edits(source, target, seed);
1633
1634 if cursor.is_some() {
1635 let src_patch = Patch::parse_unified_diff(&new_src);
1636 let tgt_patch = Patch::parse_unified_diff(&new_tgt);
1637
1638 // Find the added line in source
1639 for hunk in &src_patch.hunks {
1640 for line in &hunk.lines {
1641 if let PatchLine::Addition(content) = line {
1642 // The partial line should be a prefix of the full line
1643 let full_line = " println!(\"hello world\");";
1644 if content != full_line && full_line.starts_with(content) {
1645 found_partial = true;
1646
1647 // Verify target has the partial as deletion
1648 let mut has_deletion = false;
1649 for tgt_hunk in &tgt_patch.hunks {
1650 for tgt_line in &tgt_hunk.lines {
1651 if let PatchLine::Deletion(del_content) = tgt_line {
1652 if del_content == content {
1653 has_deletion = true;
1654 }
1655 }
1656 }
1657 }
1658 assert!(
1659 has_deletion,
1660 "Target should have deletion of partial line"
1661 );
1662 }
1663 }
1664 }
1665 }
1666 }
1667 }
1668
1669 assert!(
1670 found_partial,
1671 "At least one seed should produce a partial intermediate state"
1672 );
1673 }
1674
1675 #[test]
1676 fn test_cursor_excerpt_with_multibyte_utf8() {
1677 // Test that cursor excerpt handles multi-byte UTF-8 characters correctly
1678 // The Chinese character '第' is 3 bytes (0..3)
1679 let cursor = CursorPosition {
1680 file: "test.md".to_string(),
1681 line: 1,
1682 column: 1, // Byte index 1 is inside '第' (bytes 0..3)
1683 };
1684
1685 let source_patch = r#"--- a/test.md
1686+++ b/test.md
1687@@ -1,1 +1,1 @@
1688+第 14 章 Flask 工作原理与机制解析**
1689"#;
1690
1691 let target_patch = "";
1692
1693 // This should not panic even though column=1 is not a char boundary
1694 let result = get_cursor_excerpt(&cursor, source_patch, target_patch);
1695
1696 // The function should handle the invalid byte index gracefully
1697 if let Some(excerpt) = result {
1698 assert!(
1699 excerpt.contains("<|user_cursor|>"),
1700 "Cursor excerpt should contain marker"
1701 );
1702 // The marker should be placed at a valid character boundary
1703 // (either at the start or after '第')
1704 }
1705 }
1706}