1use crate::udiff::DiffLine;
2use anyhow::{Context as _, Result};
3use serde::{Deserialize, Serialize};
4use std::{borrow::Cow, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc};
5use telemetry_events::EditPredictionRating;
6
/// Marker text placed on a comment line that points at the cursor column.
///
/// In `ExampleSpec::cursor_position` the marker line sits directly below the cursor line; in
/// expected patches it is a `#` line following the line it refers to. The character just
/// before the marker (`^` or `<`) indicates how the column is determined.
pub const CURSOR_POSITION_MARKER: &str = "[CURSOR_POSITION]";
/// Alternative cursor marker embedded directly at the cursor location within the text.
pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>";

/// Maximum cursor file size to capture, in bytes (64 KiB).
/// Files larger than this will not have their content captured,
/// falling back to git-based loading.
pub const MAX_CURSOR_FILE_SIZE: usize = 64 * 1024;
14
15/// Encodes a cursor position into a diff patch by adding a comment line with a caret
16/// pointing to the cursor column.
17///
18/// The cursor offset is relative to the start of the new text content (additions and context lines).
19/// Returns the patch with cursor marker comment lines inserted after the relevant addition line.
20pub fn encode_cursor_in_patch(patch: &str, cursor_offset: Option<usize>) -> String {
21 let Some(cursor_offset) = cursor_offset else {
22 return patch.to_string();
23 };
24
25 let mut result = String::new();
26 let mut line_start_offset = 0usize;
27
28 for line in patch.lines() {
29 if !result.is_empty() {
30 result.push('\n');
31 }
32 result.push_str(line);
33
34 match DiffLine::parse(line) {
35 DiffLine::Addition(content) => {
36 let line_end_offset = line_start_offset + content.len();
37
38 if cursor_offset >= line_start_offset && cursor_offset <= line_end_offset {
39 let cursor_column = cursor_offset - line_start_offset;
40
41 result.push('\n');
42 result.push('#');
43 for _ in 0..cursor_column {
44 result.push(' ');
45 }
46 write!(result, "^{}", CURSOR_POSITION_MARKER).unwrap();
47 }
48
49 line_start_offset = line_end_offset + 1;
50 }
51 DiffLine::Context(content) => {
52 line_start_offset += content.len() + 1;
53 }
54 _ => {}
55 }
56 }
57
58 if patch.ends_with('\n') {
59 result.push('\n');
60 }
61
62 result
63}
64
/// A single edit prediction example: repository location, recent edits, the cursor
/// location, and the patch(es) the prediction is expected to produce.
///
/// Serialized with serde using the fields below. It can also be converted to and from a
/// markdown form with [`ExampleSpec::to_markdown`] and [`ExampleSpec::from_markdown`].
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct ExampleSpec {
    /// Example name; also the source of [`ExampleSpec::filename`].
    #[serde(default)]
    pub name: String,
    /// URL of the repository the example refers to.
    pub repository_url: String,
    /// Repository revision the example refers to (presumably a commit id — confirm).
    pub revision: String,
    /// Free-form tags; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tags: Vec<String>,
    /// Optional free-form explanation, written under a "Reasoning" heading in markdown.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
    /// Uncommitted changes, as diff text; empty when there are none.
    #[serde(default)]
    pub uncommitted_diff: String,
    /// Path of the file that contains the cursor.
    pub cursor_path: Arc<Path>,
    /// Excerpt of the cursor file with the cursor marked, either inline with
    /// `<|user_cursor|>` or by a `[CURSOR_POSITION]` marker line (see `cursor_excerpt`).
    pub cursor_position: String,
    /// Recent edits as concatenated diff text. A diff that was an accepted prediction is
    /// preceded by a `// User accepted prediction:` line.
    pub edit_history: String,
    /// Expected patches. Each may contain a `#` cursor-marker line
    /// (see `expected_patches_with_cursor_positions`).
    pub expected_patches: Vec<String>,
    /// Optional rejected patch, written under a "Rejected Patch" heading in markdown.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rejected_patch: Option<String>,
    /// Optional snapshot of the inputs needed to build the prompt without the project.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub captured_prompt_input: Option<CapturedPromptInput>,
    /// Optional metadata for examples that originate from telemetry.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub telemetry: Option<TelemetrySource>,
    /// Human feedback messages attached to the example.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub human_feedback: Vec<HumanFeedback>,
    /// Optional rating of the prediction.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rating: Option<EditPredictionRating>,
}
92
/// A single piece of human feedback attached to an example.
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct HumanFeedback {
    /// The feedback text.
    pub message: String,
}
97
/// Metadata for examples sourced from production telemetry (rejected predictions).
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct TelemetrySource {
    /// Identifier of the originating prediction request.
    pub request_id: String,
    /// Identifier of the device that made the request.
    pub device_id: String,
    /// Time of the event, kept as a string (format not specified here — confirm with producer).
    pub time: String,
    /// Why the prediction was rejected.
    pub rejection_reason: String,
    /// Presumably whether the prediction was displayed to the user before rejection — verify.
    pub was_shown: bool,
}
107
/// All data needed to run format_prompt without loading the project.
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct CapturedPromptInput {
    /// Full text of the cursor file at capture time.
    pub cursor_file_content: String,
    /// Cursor location within `cursor_file_content` (presumably a byte offset — confirm).
    pub cursor_offset: usize,
    /// Cursor row (base not specified here — confirm whether zero-based).
    pub cursor_row: u32,
    /// Cursor column (units not specified here — confirm bytes vs. chars).
    pub cursor_column: u32,
    /// Optional first row of the excerpt used for the prompt; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub excerpt_start_row: Option<u32>,
    /// Captured buffer-change events (see [`CapturedEvent::to_event`]).
    pub events: Vec<CapturedEvent>,
    /// Captured related-file excerpts (see [`CapturedRelatedFile::to_related_file`]).
    pub related_files: Vec<CapturedRelatedFile>,
    /// Whether the repository is open source; defaults to `false` when absent.
    #[serde(default)]
    pub in_open_source_repo: bool,
    /// Optional editor version string recorded at capture time.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub zed_version: Option<String>,
}
124
/// A captured buffer-change event; convertible to `zeta_prompt::Event::BufferChange`.
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct CapturedEvent {
    /// Path of the changed buffer.
    pub path: Arc<Path>,
    /// Previous path of the buffer (presumably differs from `path` only on rename — confirm).
    pub old_path: Arc<Path>,
    /// Diff text describing the change.
    pub diff: String,
    /// Whether the change came from a prediction.
    pub predicted: bool,
    /// Whether the buffer belongs to an open-source repository; defaults to `false`.
    #[serde(default)]
    pub in_open_source_repo: bool,
}
134
135impl CapturedEvent {
136 pub fn to_event(&self) -> zeta_prompt::Event {
137 zeta_prompt::Event::BufferChange {
138 path: self.path.clone(),
139 old_path: self.old_path.clone(),
140 diff: self.diff.clone(),
141 predicted: self.predicted,
142 in_open_source_repo: self.in_open_source_repo,
143 }
144 }
145}
146
/// A related file captured for the prompt, with the excerpts that were included.
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct CapturedRelatedFile {
    /// Path of the related file.
    pub path: Arc<Path>,
    /// Maximum row of the file (presumably its last row index — confirm).
    pub max_row: u32,
    /// Excerpts taken from the file.
    pub excerpts: Vec<CapturedRelatedExcerpt>,
}
153
154impl CapturedRelatedFile {
155 pub fn to_related_file(&self) -> zeta_prompt::RelatedFile {
156 zeta_prompt::RelatedFile {
157 path: self.path.clone(),
158 max_row: self.max_row,
159 in_open_source_repo: false,
160 excerpts: self
161 .excerpts
162 .iter()
163 .map(|e| zeta_prompt::RelatedExcerpt {
164 row_range: e.row_range.clone(),
165 text: e.text.clone().into(),
166 })
167 .collect(),
168 }
169 }
170}
171
/// One excerpt of a related file.
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct CapturedRelatedExcerpt {
    /// Rows covered by the excerpt (half-open `start..end`).
    pub row_range: Range<u32>,
    /// Text of the excerpt.
    pub text: String,
}
177
// H2 section headings used in the markdown form of `ExampleSpec`.
const REASONING_HEADING: &str = "Reasoning";
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const REJECTED_PATCH_HEADING: &str = "Rejected Patch";
/// Line that precedes an edit-history diff which was an accepted prediction.
const ACCEPTED_PREDICTION_MARKER: &str = "// User accepted prediction:";
185
/// TOML front matter written between `+++` lines at the top of the markdown form.
///
/// The string fields are `Cow` so that serialization can borrow from an `ExampleSpec`
/// while deserialization may produce owned values.
#[derive(Serialize, Deserialize)]
struct FrontMatter<'a> {
    repository_url: Cow<'a, str>,
    revision: Cow<'a, str>,
    // Omitted from the serialized front matter when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    tags: Vec<String>,
}
193
194impl ExampleSpec {
195 /// Generate a sanitized filename for this example.
196 pub fn filename(&self) -> String {
197 self.name
198 .chars()
199 .map(|c| match c {
200 ' ' | ':' | '~' | '^' | '?' | '*' | '[' | '\\' | '@' | '{' | '/' | '<' | '>'
201 | '|' | '"' => '-',
202 c => c,
203 })
204 .collect()
205 }
206
207 /// Format this example spec as markdown.
208 pub fn to_markdown(&self) -> String {
209 use std::fmt::Write as _;
210
211 let front_matter = FrontMatter {
212 repository_url: Cow::Borrowed(&self.repository_url),
213 revision: Cow::Borrowed(&self.revision),
214 tags: self.tags.clone(),
215 };
216 let front_matter_toml =
217 toml::to_string_pretty(&front_matter).unwrap_or_else(|_| String::new());
218
219 let mut markdown = String::new();
220
221 _ = writeln!(markdown, "+++");
222 markdown.push_str(&front_matter_toml);
223 if !markdown.ends_with('\n') {
224 markdown.push('\n');
225 }
226 _ = writeln!(markdown, "+++");
227 markdown.push('\n');
228
229 _ = writeln!(markdown, "# {}", self.name);
230 markdown.push('\n');
231
232 if let Some(reasoning) = &self.reasoning {
233 _ = writeln!(markdown, "## {}", REASONING_HEADING);
234 markdown.push('\n');
235 markdown.push_str(reasoning);
236 if !markdown.ends_with('\n') {
237 markdown.push('\n');
238 }
239 markdown.push('\n');
240 }
241
242 if !self.uncommitted_diff.is_empty() {
243 _ = writeln!(markdown, "## {}", UNCOMMITTED_DIFF_HEADING);
244 _ = writeln!(markdown);
245 _ = writeln!(markdown, "```diff");
246 markdown.push_str(&self.uncommitted_diff);
247 if !markdown.ends_with('\n') {
248 markdown.push('\n');
249 }
250 _ = writeln!(markdown, "```");
251 markdown.push('\n');
252 }
253
254 _ = writeln!(markdown, "## {}", EDIT_HISTORY_HEADING);
255 _ = writeln!(markdown);
256
257 if self.edit_history.is_empty() {
258 _ = writeln!(markdown, "(No edit history)");
259 _ = writeln!(markdown);
260 } else {
261 _ = writeln!(markdown, "```diff");
262 markdown.push_str(&self.edit_history);
263 if !markdown.ends_with('\n') {
264 markdown.push('\n');
265 }
266 _ = writeln!(markdown, "```");
267 markdown.push('\n');
268 }
269
270 _ = writeln!(markdown, "## {}", CURSOR_POSITION_HEADING);
271 _ = writeln!(markdown);
272 _ = writeln!(markdown, "```{}", self.cursor_path.to_string_lossy());
273 markdown.push_str(&self.cursor_position);
274 if !markdown.ends_with('\n') {
275 markdown.push('\n');
276 }
277 _ = writeln!(markdown, "```");
278 markdown.push('\n');
279
280 _ = writeln!(markdown, "## {}", EXPECTED_PATCH_HEADING);
281 markdown.push('\n');
282 for patch in &self.expected_patches {
283 _ = writeln!(markdown, "```diff");
284 markdown.push_str(patch);
285 if !markdown.ends_with('\n') {
286 markdown.push('\n');
287 }
288 _ = writeln!(markdown, "```");
289 markdown.push('\n');
290 }
291
292 if let Some(rejected_patch) = &self.rejected_patch {
293 _ = writeln!(markdown, "## {}", REJECTED_PATCH_HEADING);
294 markdown.push('\n');
295 _ = writeln!(markdown, "```diff");
296 markdown.push_str(rejected_patch);
297 if !markdown.ends_with('\n') {
298 markdown.push('\n');
299 }
300 _ = writeln!(markdown, "```");
301 markdown.push('\n');
302 }
303
304 markdown
305 }
306
307 /// Parse an example spec from markdown.
308 pub fn from_markdown(mut input: &str) -> anyhow::Result<Self> {
309 use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
310
311 let mut spec = ExampleSpec {
312 name: String::new(),
313 repository_url: String::new(),
314 revision: String::new(),
315 tags: Vec::new(),
316 reasoning: None,
317 uncommitted_diff: String::new(),
318 cursor_path: Path::new("").into(),
319 cursor_position: String::new(),
320 edit_history: String::new(),
321 expected_patches: Vec::new(),
322 rejected_patch: None,
323 captured_prompt_input: None,
324 telemetry: None,
325 human_feedback: Vec::new(),
326 rating: None,
327 };
328
329 if let Some(rest) = input.strip_prefix("+++\n")
330 && let Some((front_matter, rest)) = rest.split_once("+++\n")
331 {
332 if let Ok(data) = toml::from_str::<FrontMatter<'_>>(front_matter) {
333 spec.repository_url = data.repository_url.into_owned();
334 spec.revision = data.revision.into_owned();
335 spec.tags = data.tags;
336 }
337 input = rest.trim_start();
338 }
339
340 let parser = Parser::new(input);
341 let mut text = String::new();
342 let mut block_info: CowStr = "".into();
343
344 #[derive(PartialEq)]
345 enum Section {
346 Start,
347 UncommittedDiff,
348 EditHistory,
349 CursorPosition,
350 ExpectedPatch,
351 RejectedPatch,
352 Other,
353 }
354
355 let mut current_section = Section::Start;
356 let mut next_edit_predicted = false;
357
358 for event in parser {
359 match event {
360 Event::Text(line) => {
361 text.push_str(&line);
362 }
363 Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
364 spec.name = mem::take(&mut text);
365 }
366 Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
367 let title = mem::take(&mut text);
368 current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
369 Section::UncommittedDiff
370 } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
371 Section::EditHistory
372 } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
373 Section::CursorPosition
374 } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
375 Section::ExpectedPatch
376 } else if title.eq_ignore_ascii_case(REJECTED_PATCH_HEADING) {
377 Section::RejectedPatch
378 } else {
379 Section::Other
380 };
381 }
382 Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
383 mem::take(&mut text);
384 }
385 Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
386 mem::take(&mut text);
387 }
388 Event::End(TagEnd::Heading(level)) => {
389 anyhow::bail!("Unexpected heading level: {level}");
390 }
391 Event::Start(Tag::CodeBlock(kind)) => {
392 if current_section == Section::EditHistory
393 && text.trim() == ACCEPTED_PREDICTION_MARKER
394 {
395 next_edit_predicted = true;
396 }
397 text.clear();
398 match kind {
399 CodeBlockKind::Fenced(info) => {
400 block_info = info;
401 }
402 CodeBlockKind::Indented => {
403 anyhow::bail!("Unexpected indented codeblock");
404 }
405 };
406 }
407 Event::Start(_) => {
408 text.clear();
409 block_info = "".into();
410 }
411 Event::End(TagEnd::CodeBlock) => {
412 let block_info = block_info.trim();
413 match current_section {
414 Section::UncommittedDiff => {
415 spec.uncommitted_diff = mem::take(&mut text);
416 }
417 Section::EditHistory => {
418 if next_edit_predicted {
419 spec.edit_history
420 .push_str(&format!("{}\n", ACCEPTED_PREDICTION_MARKER));
421 next_edit_predicted = false;
422 }
423 spec.edit_history.push_str(&mem::take(&mut text));
424 }
425 Section::CursorPosition => {
426 spec.cursor_path = Path::new(block_info).into();
427 spec.cursor_position = mem::take(&mut text);
428 }
429 Section::ExpectedPatch => {
430 spec.expected_patches.push(mem::take(&mut text));
431 }
432 Section::RejectedPatch => {
433 spec.rejected_patch = Some(mem::take(&mut text));
434 }
435 Section::Start | Section::Other => {}
436 }
437 }
438 _ => {}
439 }
440 }
441
442 if spec.cursor_path.as_ref() == Path::new("") || spec.cursor_position.is_empty() {
443 anyhow::bail!("Missing cursor position codeblock");
444 }
445
446 Ok(spec)
447 }
448
449 /// Returns the excerpt of text around the cursor, and the offset of the cursor within that
450 /// excerpt.
451 ///
452 /// The cursor's position is marked with a special comment that appears
453 /// below the cursor line, which contains the string `[CURSOR_POSITION]`,
454 /// preceded by an arrow marking the cursor's column. The arrow can be
455 /// either:
456 /// - `^` - The cursor column is at the position of the `^` character (pointing up to the cursor)
457 /// - `<` - The cursor column is at the first non-whitespace character on that line.
458 pub fn cursor_excerpt(&self) -> Result<(String, usize)> {
459 let input = &self.cursor_position;
460
461 // Check for inline cursor marker first
462 if let Some(inline_offset) = input.find(INLINE_CURSOR_MARKER) {
463 let excerpt = input[..inline_offset].to_string()
464 + &input[inline_offset + INLINE_CURSOR_MARKER.len()..];
465 return Ok((excerpt, inline_offset));
466 }
467
468 let marker_offset = input
469 .find(CURSOR_POSITION_MARKER)
470 .context("missing [CURSOR_POSITION] marker")?;
471 let marker_line_start = input[..marker_offset]
472 .rfind('\n')
473 .map(|pos| pos + 1)
474 .unwrap_or(0);
475 let marker_line_end = input[marker_line_start..]
476 .find('\n')
477 .map(|pos| marker_line_start + pos + 1)
478 .unwrap_or(input.len());
479 let marker_line = &input[marker_line_start..marker_line_end].trim_end_matches('\n');
480
481 let cursor_column = if let Some(cursor_offset) = marker_line.find('^') {
482 cursor_offset
483 } else if let Some(less_than_pos) = marker_line.find('<') {
484 marker_line
485 .find(|c: char| !c.is_whitespace())
486 .unwrap_or(less_than_pos)
487 } else {
488 anyhow::bail!(
489 "cursor position marker line must contain '^' or '<' before [CURSOR_POSITION]"
490 );
491 };
492
493 let mut excerpt = input[..marker_line_start].to_string() + &input[marker_line_end..];
494 excerpt.truncate(excerpt.trim_end_matches('\n').len());
495
496 // The cursor is on the line above the marker line.
497 let cursor_line_end = marker_line_start.saturating_sub(1);
498 let cursor_line_start = excerpt[..cursor_line_end]
499 .rfind('\n')
500 .map(|pos| pos + 1)
501 .unwrap_or(0);
502 let cursor_offset = cursor_line_start + cursor_column;
503
504 Ok((excerpt, cursor_offset))
505 }
506
507 /// Sets the cursor position excerpt from a plain excerpt and cursor byte offset.
508 ///
509 /// The `line_comment_prefix` is used to format the marker line as a comment.
510 /// If the cursor column is less than the comment prefix length, the `<` format is used.
511 /// Otherwise, the `^` format is used.
512 pub fn set_cursor_excerpt(
513 &mut self,
514 excerpt: &str,
515 cursor_offset: usize,
516 line_comment_prefix: &str,
517 ) {
518 // Find which line the cursor is on and its column
519 let cursor_line_start = excerpt[..cursor_offset]
520 .rfind('\n')
521 .map(|pos| pos + 1)
522 .unwrap_or(0);
523 let cursor_line_end = excerpt[cursor_line_start..]
524 .find('\n')
525 .map(|pos| cursor_line_start + pos + 1)
526 .unwrap_or(excerpt.len());
527 let cursor_line = &excerpt[cursor_line_start..cursor_line_end];
528 let cursor_line_indent = &cursor_line[..cursor_line.len() - cursor_line.trim_start().len()];
529 let cursor_column = cursor_offset - cursor_line_start;
530
531 // Build the marker line
532 let mut marker_line = String::new();
533 if cursor_column < line_comment_prefix.len() {
534 for _ in 0..cursor_column {
535 marker_line.push(' ');
536 }
537 marker_line.push_str(line_comment_prefix);
538 write!(marker_line, " <{}", CURSOR_POSITION_MARKER).unwrap();
539 } else {
540 if cursor_column >= cursor_line_indent.len() + line_comment_prefix.len() {
541 marker_line.push_str(cursor_line_indent);
542 }
543 marker_line.push_str(line_comment_prefix);
544 while marker_line.len() < cursor_column {
545 marker_line.push(' ');
546 }
547 write!(marker_line, "^{}", CURSOR_POSITION_MARKER).unwrap();
548 }
549
550 // Build the final cursor_position string
551 let mut result = String::with_capacity(excerpt.len() + marker_line.len() + 2);
552 result.push_str(&excerpt[..cursor_line_end]);
553 if !result.ends_with('\n') {
554 result.push('\n');
555 }
556 result.push_str(&marker_line);
557 if cursor_line_end < excerpt.len() {
558 result.push('\n');
559 result.push_str(&excerpt[cursor_line_end..]);
560 }
561
562 self.cursor_position = result;
563 }
564
565 /// Returns all of the possible expected patches for this example, each with an optional
566 /// cursor offset.
567 ///
568 /// The cursor offset is an offset within the new text (after applying the patch), relative
569 /// to the start of the hunk.
570 ///
571 /// In the serialized representation of this example, the cursor position is represented
572 /// using a comment line in the diff, beginning with `#`, and containing a `[CURSOR_POSITION]`
573 /// marker with the same format as the [`Self::cursor_excerpt`].
574 pub fn expected_patches_with_cursor_positions(&self) -> Vec<(String, Option<usize>)> {
575 self.expected_patches
576 .iter()
577 .map(|patch| {
578 let mut clean_patch = String::new();
579 let mut cursor_offset: Option<usize> = None;
580 let mut line_start_offset = 0usize;
581 let mut prev_line_start_offset = 0usize;
582
583 for line in patch.lines() {
584 let diff_line = DiffLine::parse(line);
585
586 match &diff_line {
587 DiffLine::Garbage(content)
588 if content.starts_with('#')
589 && content.contains(CURSOR_POSITION_MARKER) =>
590 {
591 let caret_column = if let Some(caret_pos) = content.find('^') {
592 caret_pos
593 } else if let Some(_) = content.find('<') {
594 0
595 } else {
596 continue;
597 };
598 let cursor_column = caret_column.saturating_sub('#'.len_utf8());
599 cursor_offset = Some(prev_line_start_offset + cursor_column);
600 }
601 _ => {
602 if !clean_patch.is_empty() {
603 clean_patch.push('\n');
604 }
605 clean_patch.push_str(line);
606
607 match diff_line {
608 DiffLine::Addition(content) | DiffLine::Context(content) => {
609 prev_line_start_offset = line_start_offset;
610 line_start_offset += content.len() + 1;
611 }
612 _ => {}
613 }
614 }
615 }
616 }
617
618 if patch.ends_with('\n') && !clean_patch.is_empty() {
619 clean_patch.push('\n');
620 }
621
622 (clean_patch, cursor_offset)
623 })
624 .collect()
625 }
626
627 pub fn set_expected_patches_with_cursor_positions(
628 &mut self,
629 patches: Vec<(String, Option<usize>)>,
630 ) {
631 self.expected_patches = patches
632 .into_iter()
633 .map(|(patch, cursor_offset)| encode_cursor_in_patch(&patch, cursor_offset))
634 .collect();
635 }
636}
637
// Unit tests for cursor encoding in excerpts and patches, and for markdown parsing.
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;

    /// Round-trips cursor offsets through `set_cursor_excerpt` / `cursor_excerpt` and checks
    /// the exact encoded text for both the `^` and `<` marker formats.
    #[test]
    fn test_cursor_excerpt_with_caret() {
        let mut spec = ExampleSpec {
            name: String::new(),
            repository_url: String::new(),
            revision: String::new(),
            tags: Vec::new(),
            reasoning: None,
            uncommitted_diff: String::new(),
            cursor_path: Path::new("test.rs").into(),
            cursor_position: String::new(),
            edit_history: String::new(),
            expected_patches: Vec::new(),
            rejected_patch: None,
            captured_prompt_input: None,
            telemetry: None,
            human_feedback: Vec::new(),
            rating: None,
        };

        // Cursor before `42`
        let excerpt = indoc! {"
            fn main() {
                let x = 42;
                println!(\"{}\", x);
            }"
        };
        let offset = excerpt.find("42").unwrap();
        let position_string = indoc! {"
            fn main() {
                let x = 42;
                //      ^[CURSOR_POSITION]
                println!(\"{}\", x);
            }"
        }
        .to_string();

        spec.set_cursor_excerpt(excerpt, offset, "//");
        assert_eq!(spec.cursor_position, position_string);
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (excerpt.to_string(), offset)
        );

        // Cursor after `l` in `let`
        let offset = excerpt.find("et x").unwrap();
        let position_string = indoc! {"
            fn main() {
                let x = 42;
            //   ^[CURSOR_POSITION]
                println!(\"{}\", x);
            }"
        }
        .to_string();

        spec.set_cursor_excerpt(excerpt, offset, "//");
        assert_eq!(spec.cursor_position, position_string);
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (excerpt.to_string(), offset)
        );

        // Cursor before `let`
        let offset = excerpt.find("let").unwrap();
        let position_string = indoc! {"
            fn main() {
                let x = 42;
            //  ^[CURSOR_POSITION]
                println!(\"{}\", x);
            }"
        }
        .to_string();

        spec.set_cursor_excerpt(excerpt, offset, "//");
        assert_eq!(spec.cursor_position, position_string);
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (excerpt.to_string(), offset)
        );

        // Cursor at beginning of the line with `let` (column 0 is narrower than the `//`
        // prefix, so the `<` format is used)
        let offset = excerpt.find("    let").unwrap();
        let position_string = indoc! {"
            fn main() {
                let x = 42;
            // <[CURSOR_POSITION]
                println!(\"{}\", x);
            }"
        }
        .to_string();

        spec.set_cursor_excerpt(excerpt, offset, "//");
        assert_eq!(spec.cursor_position, position_string);
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (excerpt.to_string(), offset)
        );

        // Cursor at end of line, after the semicolon
        let offset = excerpt.find(';').unwrap() + 1;
        let position_string = indoc! {"
            fn main() {
                let x = 42;
                //         ^[CURSOR_POSITION]
                println!(\"{}\", x);
            }"
        }
        .to_string();

        spec.set_cursor_excerpt(excerpt, offset, "//");
        assert_eq!(spec.cursor_position, position_string);
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (excerpt.to_string(), offset)
        );

        // Caret at end of file (no trailing newline)
        let excerpt = indoc! {"
            fn main() {
                let x = 42;"
        };
        let offset = excerpt.find(';').unwrap() + 1;
        let position_string = indoc! {"
            fn main() {
                let x = 42;
                //         ^[CURSOR_POSITION]"
        }
        .to_string();

        spec.set_cursor_excerpt(excerpt, offset, "//");
        assert_eq!(spec.cursor_position, position_string);
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (excerpt.to_string(), offset)
        );
    }

    /// `cursor_excerpt` strips an inline `<|user_cursor|>` marker and reports its offset.
    #[test]
    fn test_cursor_excerpt_with_inline_marker() {
        let mut spec = ExampleSpec {
            name: String::new(),
            repository_url: String::new(),
            revision: String::new(),
            tags: Vec::new(),
            reasoning: None,
            uncommitted_diff: String::new(),
            cursor_path: Path::new("test.rs").into(),
            cursor_position: String::new(),
            edit_history: String::new(),
            expected_patches: Vec::new(),
            rejected_patch: None,
            captured_prompt_input: None,
            telemetry: None,
            human_feedback: Vec::new(),
            rating: None,
        };

        // Cursor before `42` using inline marker
        spec.cursor_position = indoc! {"
            fn main() {
                let x = <|user_cursor|>42;
                println!(\"{}\", x);
            }"
        }
        .to_string();

        let expected_excerpt = indoc! {"
            fn main() {
                let x = 42;
                println!(\"{}\", x);
            }"
        };
        let expected_offset = expected_excerpt.find("42").unwrap();

        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (expected_excerpt.to_string(), expected_offset)
        );

        // Cursor at beginning of line
        spec.cursor_position = indoc! {"
            fn main() {
            <|user_cursor|>    let x = 42;
            }"
        }
        .to_string();

        let expected_excerpt = indoc! {"
            fn main() {
                let x = 42;
            }"
        };
        let expected_offset = expected_excerpt.find("    let").unwrap();

        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (expected_excerpt.to_string(), expected_offset)
        );

        // Cursor at end of file
        spec.cursor_position = "fn main() {}<|user_cursor|>".to_string();
        let expected_excerpt = "fn main() {}";
        let expected_offset = expected_excerpt.len();

        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (expected_excerpt.to_string(), expected_offset)
        );
    }

    /// Encodes a cursor offset into an expected patch and decodes it back, with and without
    /// a cursor.
    #[test]
    fn test_expected_patches_with_cursor_positions() {
        let mut spec = ExampleSpec {
            name: String::new(),
            repository_url: String::new(),
            revision: String::new(),
            tags: Vec::new(),
            reasoning: None,
            uncommitted_diff: String::new(),
            cursor_path: Path::new("test.rs").into(),
            cursor_position: String::new(),
            edit_history: String::new(),
            expected_patches: Vec::new(),
            rejected_patch: None,
            captured_prompt_input: None,
            telemetry: None,
            human_feedback: Vec::new(),
            rating: None,
        };

        // New-side text of the patch below; the cursor offset is measured in it.
        let new_content = indoc! {r#"
            // prints a greeting
            fn main() {
                println!("hello, {}", );
                let x = 42;
            }
        "#};
        let cursor_offset = new_content.find(");").unwrap();

        let clean_patch = indoc! {r#"
            --- a/test.rs
            +++ b/test.rs
            @@ -1,3 +1,4 @@
            +// prints a greeting
             fn main() {
            -    println!("hi");
            +    println!("hello, {}", );
                 let x = 42;
             }
        "#}
        .to_string();

        let encoded_patch = indoc! {r#"
            --- a/test.rs
            +++ b/test.rs
            @@ -1,3 +1,4 @@
            +// prints a greeting
             fn main() {
            -    println!("hi");
            +    println!("hello, {}", );
            #                          ^[CURSOR_POSITION]
                 let x = 42;
             }
        "#}
        .to_string();

        spec.set_expected_patches_with_cursor_positions(vec![(
            clean_patch.clone(),
            Some(cursor_offset),
        )]);
        assert_eq!(spec.expected_patches, vec![encoded_patch]);

        let results = spec.expected_patches_with_cursor_positions();
        assert_eq!(results, vec![(clean_patch.clone(), Some(cursor_offset))]);

        spec.set_expected_patches_with_cursor_positions(vec![(clean_patch.clone(), None)]);
        assert_eq!(spec.expected_patches, vec![clean_patch.clone()]);

        let results = spec.expected_patches_with_cursor_positions();
        assert_eq!(results, vec![(clean_patch, None)]);
    }

    /// Only the edit-history diff directly preceded by the accepted-prediction marker gets
    /// the marker in the parsed `edit_history`.
    #[test]
    fn test_from_markdown_accepted_prediction_marker() {
        let markdown = indoc! {r#"
            +++
            repository_url = "https://github.com/example/repo"
            revision = "abc123"
            +++

            ## Edit History

            ```diff
            --- a/src/main.rs
            +++ b/src/main.rs
            @@ -1,3 +1,3 @@
            -fn hello() {}
            +fn hello_world() {}
            ```

            // User accepted prediction:
            ```diff
            --- a/src/main.rs
            +++ b/src/main.rs
            @@ -1,3 +1,3 @@
            -fn hello_world() {}
            +fn hello_world() { println!("hi"); }
            ```

            ```diff
            --- a/src/main.rs
            +++ b/src/main.rs
            @@ -1,3 +1,3 @@
            -fn hello_world() { println!("hi"); }
            +fn hello_world() { println!("hello"); }
            ```

            ## Cursor Position

            ```src/main.rs
            fn hello_world() { println!("hello"); }
            #                  ^[CURSOR_POSITION]
            ```

            ## Expected Patch

            ```diff
            --- a/src/main.rs
            +++ b/src/main.rs
            @@ -1,3 +1,3 @@
            -fn hello_world() { println!("hello"); }
            +fn hello_world() { println!("hello, world!"); }
            ```
        "#};

        let spec = ExampleSpec::from_markdown(markdown).unwrap();

        // The first diff should NOT have the marker
        assert!(spec.edit_history.starts_with("--- a/src/main.rs"));

        // The second diff should be preceded by the accepted prediction marker
        assert!(
            spec.edit_history
                .contains("// User accepted prediction:\n--- a/src/main.rs")
        );

        // Count occurrences of the marker - should be exactly one
        let marker_count = spec
            .edit_history
            .matches("// User accepted prediction:")
            .count();
        assert_eq!(marker_count, 1);

        // The third diff should NOT have the marker
        // Verify all three diffs are present
        let diff_count = spec.edit_history.matches("--- a/src/main.rs").count();
        assert_eq!(diff_count, 3);
    }
}