1use anyhow::{Context as _, Result};
2use serde::{Deserialize, Serialize};
3use std::{borrow::Cow, fmt::Write as _, mem, path::Path, sync::Arc};
4use telemetry_events::EditPredictionRating;
5
6pub use zeta_prompt::udiff::{
7 CURSOR_POSITION_MARKER, encode_cursor_in_patch, extract_cursor_from_patch,
8};
/// Marker placed directly inside `ExampleSpec::cursor_position` at the cursor location.
///
/// When present, it takes precedence over the `[CURSOR_POSITION]` marker-line format and is
/// stripped from the text returned by `ExampleSpec::cursor_excerpt`.
pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>";

/// Maximum cursor file size to capture (64KB).
/// Files larger than this will not have their content captured,
/// falling back to git-based loading.
pub const MAX_CURSOR_FILE_SIZE: usize = 64 * 1024;
15
16#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
17pub struct ExampleSpec {
18 #[serde(default)]
19 pub name: String,
20 pub repository_url: String,
21 pub revision: String,
22 #[serde(default, skip_serializing_if = "Vec::is_empty")]
23 pub tags: Vec<String>,
24 #[serde(default, skip_serializing_if = "Option::is_none")]
25 pub reasoning: Option<String>,
26 #[serde(default)]
27 pub uncommitted_diff: String,
28 pub cursor_path: Arc<Path>,
29 pub cursor_position: String,
30 pub edit_history: String,
31 pub expected_patches: Vec<String>,
32 #[serde(default, skip_serializing_if = "Option::is_none")]
33 pub rejected_patch: Option<String>,
34 #[serde(default, skip_serializing_if = "Option::is_none")]
35 pub telemetry: Option<TelemetrySource>,
36 #[serde(default, skip_serializing_if = "Vec::is_empty")]
37 pub human_feedback: Vec<HumanFeedback>,
38 #[serde(default, skip_serializing_if = "Option::is_none")]
39 pub rating: Option<EditPredictionRating>,
40}
41
42#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
43pub struct HumanFeedback {
44 pub message: String,
45}
46
47/// Metadata for examples sourced from production telemetry (rejected predictions).
48#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
49pub struct TelemetrySource {
50 pub request_id: String,
51 pub device_id: String,
52 pub time: String,
53 pub rejection_reason: String,
54 pub was_shown: bool,
55}
56
// `##` section headings of the markdown format written by `ExampleSpec::to_markdown`.
// `ExampleSpec::from_markdown` compares section titles against these case-insensitively.
const REASONING_HEADING: &str = "Reasoning";
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const REJECTED_PATCH_HEADING: &str = "Rejected Patch";
// In the Edit History section, a paragraph consisting of exactly this text marks the following
// code block as an accepted prediction; `from_markdown` re-inserts it (plus a newline) in front
// of that block's diff in `edit_history`.
const ACCEPTED_PREDICTION_MARKER: &str = "// User accepted prediction:";
64
/// TOML front matter placed between the two `+++` lines at the top of an example's markdown.
///
/// The `Cow` fields let `to_markdown` serialize borrowed strings without cloning them, while
/// `from_markdown` converts them into owned strings.
#[derive(Serialize, Deserialize)]
struct FrontMatter<'a> {
    repository_url: Cow<'a, str>,
    revision: Cow<'a, str>,
    // Omitted from the serialized TOML when empty, and defaults to empty when absent.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    tags: Vec<String>,
}
72
73impl ExampleSpec {
74 /// Generate a sanitized filename for this example.
75 pub fn filename(&self) -> String {
76 self.name
77 .chars()
78 .map(|c| match c {
79 ' ' | ':' | '~' | '^' | '?' | '*' | '[' | '\\' | '@' | '{' | '/' | '<' | '>'
80 | '|' | '"' => '-',
81 c => c,
82 })
83 .collect()
84 }
85
86 /// Format this example spec as markdown.
87 pub fn to_markdown(&self) -> String {
88 use std::fmt::Write as _;
89
90 let front_matter = FrontMatter {
91 repository_url: Cow::Borrowed(&self.repository_url),
92 revision: Cow::Borrowed(&self.revision),
93 tags: self.tags.clone(),
94 };
95 let front_matter_toml =
96 toml::to_string_pretty(&front_matter).unwrap_or_else(|_| String::new());
97
98 let mut markdown = String::new();
99
100 _ = writeln!(markdown, "+++");
101 markdown.push_str(&front_matter_toml);
102 if !markdown.ends_with('\n') {
103 markdown.push('\n');
104 }
105 _ = writeln!(markdown, "+++");
106 markdown.push('\n');
107
108 _ = writeln!(markdown, "# {}", self.name);
109 markdown.push('\n');
110
111 if let Some(reasoning) = &self.reasoning {
112 _ = writeln!(markdown, "## {}", REASONING_HEADING);
113 markdown.push('\n');
114 markdown.push_str(reasoning);
115 if !markdown.ends_with('\n') {
116 markdown.push('\n');
117 }
118 markdown.push('\n');
119 }
120
121 if !self.uncommitted_diff.is_empty() {
122 _ = writeln!(markdown, "## {}", UNCOMMITTED_DIFF_HEADING);
123 _ = writeln!(markdown);
124 _ = writeln!(markdown, "```diff");
125 markdown.push_str(&self.uncommitted_diff);
126 if !markdown.ends_with('\n') {
127 markdown.push('\n');
128 }
129 _ = writeln!(markdown, "```");
130 markdown.push('\n');
131 }
132
133 _ = writeln!(markdown, "## {}", EDIT_HISTORY_HEADING);
134 _ = writeln!(markdown);
135
136 if self.edit_history.is_empty() {
137 _ = writeln!(markdown, "(No edit history)");
138 _ = writeln!(markdown);
139 } else {
140 _ = writeln!(markdown, "```diff");
141 markdown.push_str(&self.edit_history);
142 if !markdown.ends_with('\n') {
143 markdown.push('\n');
144 }
145 _ = writeln!(markdown, "```");
146 markdown.push('\n');
147 }
148
149 _ = writeln!(markdown, "## {}", CURSOR_POSITION_HEADING);
150 _ = writeln!(markdown);
151 _ = writeln!(markdown, "```{}", self.cursor_path.to_string_lossy());
152 markdown.push_str(&self.cursor_position);
153 if !markdown.ends_with('\n') {
154 markdown.push('\n');
155 }
156 _ = writeln!(markdown, "```");
157 markdown.push('\n');
158
159 _ = writeln!(markdown, "## {}", EXPECTED_PATCH_HEADING);
160 markdown.push('\n');
161 for patch in &self.expected_patches {
162 _ = writeln!(markdown, "```diff");
163 markdown.push_str(patch);
164 if !markdown.ends_with('\n') {
165 markdown.push('\n');
166 }
167 _ = writeln!(markdown, "```");
168 markdown.push('\n');
169 }
170
171 if let Some(rejected_patch) = &self.rejected_patch {
172 _ = writeln!(markdown, "## {}", REJECTED_PATCH_HEADING);
173 markdown.push('\n');
174 _ = writeln!(markdown, "```diff");
175 markdown.push_str(rejected_patch);
176 if !markdown.ends_with('\n') {
177 markdown.push('\n');
178 }
179 _ = writeln!(markdown, "```");
180 markdown.push('\n');
181 }
182
183 markdown
184 }
185
186 /// Parse an example spec from markdown.
187 pub fn from_markdown(mut input: &str) -> anyhow::Result<Self> {
188 use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
189
190 let mut spec = ExampleSpec {
191 name: String::new(),
192 repository_url: String::new(),
193 revision: String::new(),
194 tags: Vec::new(),
195 reasoning: None,
196 uncommitted_diff: String::new(),
197 cursor_path: Path::new("").into(),
198 cursor_position: String::new(),
199 edit_history: String::new(),
200 expected_patches: Vec::new(),
201 rejected_patch: None,
202 telemetry: None,
203 human_feedback: Vec::new(),
204 rating: None,
205 };
206
207 if let Some(rest) = input.strip_prefix("+++\n")
208 && let Some((front_matter, rest)) = rest.split_once("+++\n")
209 {
210 if let Ok(data) = toml::from_str::<FrontMatter<'_>>(front_matter) {
211 spec.repository_url = data.repository_url.into_owned();
212 spec.revision = data.revision.into_owned();
213 spec.tags = data.tags;
214 }
215 input = rest.trim_start();
216 }
217
218 let parser = Parser::new(input);
219 let mut text = String::new();
220 let mut block_info: CowStr = "".into();
221
222 #[derive(PartialEq)]
223 enum Section {
224 Start,
225 UncommittedDiff,
226 EditHistory,
227 CursorPosition,
228 ExpectedPatch,
229 RejectedPatch,
230 Other,
231 }
232
233 let mut current_section = Section::Start;
234 let mut next_edit_predicted = false;
235
236 for event in parser {
237 match event {
238 Event::Text(line) => {
239 text.push_str(&line);
240 }
241 Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
242 spec.name = mem::take(&mut text);
243 }
244 Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
245 let title = mem::take(&mut text);
246 current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
247 Section::UncommittedDiff
248 } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
249 Section::EditHistory
250 } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
251 Section::CursorPosition
252 } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
253 Section::ExpectedPatch
254 } else if title.eq_ignore_ascii_case(REJECTED_PATCH_HEADING) {
255 Section::RejectedPatch
256 } else {
257 Section::Other
258 };
259 }
260 Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
261 mem::take(&mut text);
262 }
263 Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
264 mem::take(&mut text);
265 }
266 Event::End(TagEnd::Heading(level)) => {
267 anyhow::bail!("Unexpected heading level: {level}");
268 }
269 Event::Start(Tag::CodeBlock(kind)) => {
270 if current_section == Section::EditHistory
271 && text.trim() == ACCEPTED_PREDICTION_MARKER
272 {
273 next_edit_predicted = true;
274 }
275 text.clear();
276 match kind {
277 CodeBlockKind::Fenced(info) => {
278 block_info = info;
279 }
280 CodeBlockKind::Indented => {
281 anyhow::bail!("Unexpected indented codeblock");
282 }
283 };
284 }
285 Event::Start(_) => {
286 text.clear();
287 block_info = "".into();
288 }
289 Event::End(TagEnd::CodeBlock) => {
290 let block_info = block_info.trim();
291 match current_section {
292 Section::UncommittedDiff => {
293 spec.uncommitted_diff = mem::take(&mut text);
294 }
295 Section::EditHistory => {
296 if next_edit_predicted {
297 spec.edit_history
298 .push_str(&format!("{}\n", ACCEPTED_PREDICTION_MARKER));
299 next_edit_predicted = false;
300 }
301 spec.edit_history.push_str(&mem::take(&mut text));
302 }
303 Section::CursorPosition => {
304 spec.cursor_path = Path::new(block_info).into();
305 spec.cursor_position = mem::take(&mut text);
306 }
307 Section::ExpectedPatch => {
308 spec.expected_patches.push(mem::take(&mut text));
309 }
310 Section::RejectedPatch => {
311 spec.rejected_patch = Some(mem::take(&mut text));
312 }
313 Section::Start | Section::Other => {}
314 }
315 }
316 _ => {}
317 }
318 }
319
320 if spec.cursor_path.as_ref() == Path::new("") || spec.cursor_position.is_empty() {
321 anyhow::bail!("Missing cursor position codeblock");
322 }
323
324 Ok(spec)
325 }
326
327 /// Returns the excerpt of text around the cursor, and the offset of the cursor within that
328 /// excerpt.
329 ///
330 /// The cursor's position is marked with a special comment that appears
331 /// below the cursor line, which contains the string `[CURSOR_POSITION]`,
332 /// preceded by an arrow marking the cursor's column. The arrow can be
333 /// either:
334 /// - `^` - The cursor column is at the position of the `^` character (pointing up to the cursor)
335 /// - `<` - The cursor column is at the first non-whitespace character on that line.
336 pub fn cursor_excerpt(&self) -> Result<(String, usize)> {
337 let input = &self.cursor_position;
338
339 // Check for inline cursor marker first
340 if let Some(inline_offset) = input.find(INLINE_CURSOR_MARKER) {
341 let excerpt = input[..inline_offset].to_string()
342 + &input[inline_offset + INLINE_CURSOR_MARKER.len()..];
343 return Ok((excerpt, inline_offset));
344 }
345
346 let marker_offset = input
347 .find(CURSOR_POSITION_MARKER)
348 .context("missing [CURSOR_POSITION] marker")?;
349 let marker_line_start = input[..marker_offset]
350 .rfind('\n')
351 .map(|pos| pos + 1)
352 .unwrap_or(0);
353 let marker_line_end = input[marker_line_start..]
354 .find('\n')
355 .map(|pos| marker_line_start + pos + 1)
356 .unwrap_or(input.len());
357 let marker_line = &input[marker_line_start..marker_line_end].trim_end_matches('\n');
358
359 let cursor_column = if let Some(cursor_offset) = marker_line.find('^') {
360 cursor_offset
361 } else if let Some(less_than_pos) = marker_line.find('<') {
362 marker_line
363 .find(|c: char| !c.is_whitespace())
364 .unwrap_or(less_than_pos)
365 } else {
366 anyhow::bail!(
367 "cursor position marker line must contain '^' or '<' before [CURSOR_POSITION]"
368 );
369 };
370
371 let mut excerpt = input[..marker_line_start].to_string() + &input[marker_line_end..];
372 excerpt.truncate(excerpt.trim_end_matches('\n').len());
373
374 // The cursor is on the line above the marker line.
375 let cursor_line_end = marker_line_start.saturating_sub(1);
376 let cursor_line_start = excerpt[..cursor_line_end]
377 .rfind('\n')
378 .map(|pos| pos + 1)
379 .unwrap_or(0);
380 let cursor_offset = cursor_line_start + cursor_column;
381
382 Ok((excerpt, cursor_offset))
383 }
384
385 /// Sets the cursor position excerpt from a plain excerpt and cursor byte offset.
386 ///
387 /// The `line_comment_prefix` is used to format the marker line as a comment.
388 /// If the cursor column is less than the comment prefix length, the `<` format is used.
389 /// Otherwise, the `^` format is used.
390 pub fn set_cursor_excerpt(
391 &mut self,
392 excerpt: &str,
393 cursor_offset: usize,
394 line_comment_prefix: &str,
395 ) {
396 // Find which line the cursor is on and its column
397 let cursor_line_start = excerpt[..cursor_offset]
398 .rfind('\n')
399 .map(|pos| pos + 1)
400 .unwrap_or(0);
401 let cursor_line_end = excerpt[cursor_line_start..]
402 .find('\n')
403 .map(|pos| cursor_line_start + pos + 1)
404 .unwrap_or(excerpt.len());
405 let cursor_line = &excerpt[cursor_line_start..cursor_line_end];
406 let cursor_line_indent = &cursor_line[..cursor_line.len() - cursor_line.trim_start().len()];
407 let cursor_column = cursor_offset - cursor_line_start;
408
409 // Build the marker line
410 let mut marker_line = String::new();
411 if cursor_column < line_comment_prefix.len() {
412 for _ in 0..cursor_column {
413 marker_line.push(' ');
414 }
415 marker_line.push_str(line_comment_prefix);
416 write!(marker_line, " <{}", CURSOR_POSITION_MARKER).unwrap();
417 } else {
418 if cursor_column >= cursor_line_indent.len() + line_comment_prefix.len() {
419 marker_line.push_str(cursor_line_indent);
420 }
421 marker_line.push_str(line_comment_prefix);
422 while marker_line.len() < cursor_column {
423 marker_line.push(' ');
424 }
425 write!(marker_line, "^{}", CURSOR_POSITION_MARKER).unwrap();
426 }
427
428 // Build the final cursor_position string
429 let mut result = String::with_capacity(excerpt.len() + marker_line.len() + 2);
430 result.push_str(&excerpt[..cursor_line_end]);
431 if !result.ends_with('\n') {
432 result.push('\n');
433 }
434 result.push_str(&marker_line);
435 if cursor_line_end < excerpt.len() {
436 result.push('\n');
437 result.push_str(&excerpt[cursor_line_end..]);
438 }
439
440 self.cursor_position = result;
441 }
442
443 /// Returns all of the possible expected patches for this example, each with an optional
444 /// cursor offset.
445 ///
446 /// The cursor offset is an offset within the new text (after applying the patch), relative
447 /// to the start of the hunk.
448 ///
449 /// In the serialized representation of this example, the cursor position is represented
450 /// using a comment line in the diff, beginning with `#`, and containing a `[CURSOR_POSITION]`
451 /// marker with the same format as the [`Self::cursor_excerpt`].
452 pub fn expected_patches_with_cursor_positions(&self) -> Vec<(String, Option<usize>)> {
453 self.expected_patches
454 .iter()
455 .map(|patch| extract_cursor_from_patch(patch))
456 .collect()
457 }
458
459 pub fn set_expected_patches_with_cursor_positions(
460 &mut self,
461 patches: Vec<(String, Option<usize>)>,
462 ) {
463 self.expected_patches = patches
464 .into_iter()
465 .map(|(patch, cursor_offset)| encode_cursor_in_patch(&patch, cursor_offset))
466 .collect();
467 }
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;

    /// Round-trips an excerpt through `set_cursor_excerpt` and `cursor_excerpt` for several
    /// cursor columns, checking the exact marker line that `set_cursor_excerpt` produces.
    #[test]
    fn test_cursor_excerpt_with_caret() {
        let mut spec = ExampleSpec {
            name: String::new(),
            repository_url: String::new(),
            revision: String::new(),
            tags: Vec::new(),
            reasoning: None,
            uncommitted_diff: String::new(),
            cursor_path: Path::new("test.rs").into(),
            cursor_position: String::new(),
            edit_history: String::new(),
            expected_patches: Vec::new(),
            rejected_patch: None,
            telemetry: None,
            human_feedback: Vec::new(),
            rating: None,
        };

        // Cursor before `42`
        let excerpt = indoc! {"
            fn main() {
            let x = 42;
            println!(\"{}\", x);
            }"
        };
        let offset = excerpt.find("42").unwrap();
        let position_string = indoc! {"
            fn main() {
            let x = 42;
            // ^[CURSOR_POSITION]
            println!(\"{}\", x);
            }"
        }
        .to_string();

        spec.set_cursor_excerpt(excerpt, offset, "//");
        assert_eq!(spec.cursor_position, position_string);
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (excerpt.to_string(), offset)
        );

        // Cursor after `l` in `let`
        let offset = excerpt.find("et x").unwrap();
        let position_string = indoc! {"
            fn main() {
            let x = 42;
            // ^[CURSOR_POSITION]
            println!(\"{}\", x);
            }"
        }
        .to_string();

        spec.set_cursor_excerpt(excerpt, offset, "//");
        assert_eq!(spec.cursor_position, position_string);
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (excerpt.to_string(), offset)
        );

        // Cursor before `let`
        let offset = excerpt.find("let").unwrap();
        let position_string = indoc! {"
            fn main() {
            let x = 42;
            // ^[CURSOR_POSITION]
            println!(\"{}\", x);
            }"
        }
        .to_string();

        spec.set_cursor_excerpt(excerpt, offset, "//");
        assert_eq!(spec.cursor_position, position_string);
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (excerpt.to_string(), offset)
        );

        // Cursor at beginning of the line with `let` (column smaller than the comment
        // prefix, so the `<` format is expected)
        let offset = excerpt.find(" let").unwrap();
        let position_string = indoc! {"
            fn main() {
            let x = 42;
            // <[CURSOR_POSITION]
            println!(\"{}\", x);
            }"
        }
        .to_string();

        spec.set_cursor_excerpt(excerpt, offset, "//");
        assert_eq!(spec.cursor_position, position_string);
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (excerpt.to_string(), offset)
        );

        // Cursor at end of line, after the semicolon
        let offset = excerpt.find(';').unwrap() + 1;
        let position_string = indoc! {"
            fn main() {
            let x = 42;
            // ^[CURSOR_POSITION]
            println!(\"{}\", x);
            }"
        }
        .to_string();

        spec.set_cursor_excerpt(excerpt, offset, "//");
        assert_eq!(spec.cursor_position, position_string);
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (excerpt.to_string(), offset)
        );

        // Caret at end of file (no trailing newline)
        let excerpt = indoc! {"
            fn main() {
            let x = 42;"
        };
        let offset = excerpt.find(';').unwrap() + 1;
        let position_string = indoc! {"
            fn main() {
            let x = 42;
            // ^[CURSOR_POSITION]"
        }
        .to_string();

        spec.set_cursor_excerpt(excerpt, offset, "//");
        assert_eq!(spec.cursor_position, position_string);
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (excerpt.to_string(), offset)
        );
    }

    /// Checks that `cursor_excerpt` strips the inline `<|user_cursor|>` marker and returns its
    /// byte offset, including at the start of a line and at the end of the text.
    #[test]
    fn test_cursor_excerpt_with_inline_marker() {
        let mut spec = ExampleSpec {
            name: String::new(),
            repository_url: String::new(),
            revision: String::new(),
            tags: Vec::new(),
            reasoning: None,
            uncommitted_diff: String::new(),
            cursor_path: Path::new("test.rs").into(),
            cursor_position: String::new(),
            edit_history: String::new(),
            expected_patches: Vec::new(),
            rejected_patch: None,
            telemetry: None,
            human_feedback: Vec::new(),
            rating: None,
        };

        // Cursor before `42` using inline marker
        spec.cursor_position = indoc! {"
            fn main() {
            let x = <|user_cursor|>42;
            println!(\"{}\", x);
            }"
        }
        .to_string();

        let expected_excerpt = indoc! {"
            fn main() {
            let x = 42;
            println!(\"{}\", x);
            }"
        };
        let expected_offset = expected_excerpt.find("42").unwrap();

        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (expected_excerpt.to_string(), expected_offset)
        );

        // Cursor at beginning of line
        spec.cursor_position = indoc! {"
            fn main() {
            <|user_cursor|> let x = 42;
            }"
        }
        .to_string();

        let expected_excerpt = indoc! {"
            fn main() {
            let x = 42;
            }"
        };
        let expected_offset = expected_excerpt.find(" let").unwrap();

        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (expected_excerpt.to_string(), expected_offset)
        );

        // Cursor at end of file
        spec.cursor_position = "fn main() {}<|user_cursor|>".to_string();
        let expected_excerpt = "fn main() {}";
        let expected_offset = expected_excerpt.len();

        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (expected_excerpt.to_string(), expected_offset)
        );
    }

    /// Round-trips expected patches through `set_expected_patches_with_cursor_positions` and
    /// `expected_patches_with_cursor_positions`, with and without a cursor offset.
    #[test]
    fn test_expected_patches_with_cursor_positions() {
        let mut spec = ExampleSpec {
            name: String::new(),
            repository_url: String::new(),
            revision: String::new(),
            tags: Vec::new(),
            reasoning: None,
            uncommitted_diff: String::new(),
            cursor_path: Path::new("test.rs").into(),
            cursor_position: String::new(),
            edit_history: String::new(),
            expected_patches: Vec::new(),
            rejected_patch: None,
            telemetry: None,
            human_feedback: Vec::new(),
            rating: None,
        };

        let new_content = indoc! {r#"
            // prints a greeting
            fn main() {
            println!("hello, {}", );
            let x = 42;
            }
        "#};
        let cursor_offset = new_content.find(");").unwrap();

        let clean_patch = indoc! {r#"
            --- a/test.rs
            +++ b/test.rs
            @@ -1,3 +1,4 @@
            +// prints a greeting
            fn main() {
            - println!("hi");
            + println!("hello, {}", );
            let x = 42;
            }
        "#}
        .to_string();

        let encoded_patch = indoc! {r#"
            --- a/test.rs
            +++ b/test.rs
            @@ -1,3 +1,4 @@
            +// prints a greeting
            fn main() {
            - println!("hi");
            + println!("hello, {}", );
            # ^[CURSOR_POSITION]
            let x = 42;
            }
        "#}
        .to_string();

        spec.set_expected_patches_with_cursor_positions(vec![(
            clean_patch.clone(),
            Some(cursor_offset),
        )]);
        assert_eq!(spec.expected_patches, vec![encoded_patch]);

        let results = spec.expected_patches_with_cursor_positions();
        assert_eq!(results, vec![(clean_patch.clone(), Some(cursor_offset))]);

        // Without a cursor offset the patch is stored unchanged.
        spec.set_expected_patches_with_cursor_positions(vec![(clean_patch.clone(), None)]);
        assert_eq!(spec.expected_patches, vec![clean_patch.clone()]);

        let results = spec.expected_patches_with_cursor_positions();
        assert_eq!(results, vec![(clean_patch, None)]);
    }

    /// Encoding a cursor into a patch that already carries a cursor line must not add a second
    /// marker line.
    #[test]
    fn test_encode_cursor_in_patch_is_idempotent() {
        let patch = indoc! {r#"
            --- a/test.rs
            +++ b/test.rs
            @@ -1,2 +1,2 @@
            -fn old() {}
            +fn new_name() {}
            # ^[CURSOR_POSITION]
        "#};

        let cursor_offset = "fn new_name() {}".find("name").unwrap();
        let encoded_once = encode_cursor_in_patch(patch, Some(cursor_offset));
        let encoded_twice = encode_cursor_in_patch(&encoded_once, Some(cursor_offset));

        assert_eq!(encoded_once, encoded_twice);
        assert_eq!(
            encoded_once
                .lines()
                .filter(|line| line.contains(CURSOR_POSITION_MARKER))
                .count(),
            1
        );
    }

    /// `from_markdown` must prefix only the code block that follows the
    /// `// User accepted prediction:` paragraph with that marker in `edit_history`.
    #[test]
    fn test_from_markdown_accepted_prediction_marker() {
        let markdown = indoc! {r#"
            +++
            repository_url = "https://github.com/example/repo"
            revision = "abc123"
            +++

            ## Edit History

            ```diff
            --- a/src/main.rs
            +++ b/src/main.rs
            @@ -1,3 +1,3 @@
            -fn hello() {}
            +fn hello_world() {}
            ```

            // User accepted prediction:
            ```diff
            --- a/src/main.rs
            +++ b/src/main.rs
            @@ -1,3 +1,3 @@
            -fn hello_world() {}
            +fn hello_world() { println!("hi"); }
            ```

            ```diff
            --- a/src/main.rs
            +++ b/src/main.rs
            @@ -1,3 +1,3 @@
            -fn hello_world() { println!("hi"); }
            +fn hello_world() { println!("hello"); }
            ```

            ## Cursor Position

            ```src/main.rs
            fn hello_world() { println!("hello"); }
            # ^[CURSOR_POSITION]
            ```

            ## Expected Patch

            ```diff
            --- a/src/main.rs
            +++ b/src/main.rs
            @@ -1,3 +1,3 @@
            -fn hello_world() { println!("hello"); }
            +fn hello_world() { println!("hello, world!"); }
            ```
        "#};

        let spec = ExampleSpec::from_markdown(markdown).unwrap();

        // The first diff should NOT have the marker
        assert!(spec.edit_history.starts_with("--- a/src/main.rs"));

        // The second diff should be preceded by the accepted prediction marker
        assert!(
            spec.edit_history
                .contains("// User accepted prediction:\n--- a/src/main.rs")
        );

        // Count occurrences of the marker - should be exactly one
        let marker_count = spec
            .edit_history
            .matches("// User accepted prediction:")
            .count();
        assert_eq!(marker_count, 1);

        // The third diff should NOT have the marker
        // Verify all three diffs are present
        let diff_count = spec.edit_history.matches("--- a/src/main.rs").count();
        assert_eq!(diff_count, 3);
    }
}