example_spec.rs

  1use crate::udiff::DiffLine;
  2use anyhow::{Context as _, Result};
  3use serde::{Deserialize, Serialize};
  4use std::{borrow::Cow, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc};
  5use telemetry_events::EditPredictionRating;
  6
/// Text that ends the comment line placed directly below the cursor line in
/// [`ExampleSpec::cursor_position`] (and below patch lines in expected patches); a `^` or `<`
/// before it encodes the cursor column.
pub const CURSOR_POSITION_MARKER: &str = "[CURSOR_POSITION]";
/// Inline cursor marker. When present in [`ExampleSpec::cursor_position`] it marks the cursor
/// directly and takes precedence over [`CURSOR_POSITION_MARKER`].
pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>";

/// Maximum cursor file size to capture (64 KiB).
/// Files larger than this will not have their content captured,
/// falling back to git-based loading.
pub const MAX_CURSOR_FILE_SIZE: usize = 64 * 1024;
 14
/// One edit-prediction example: the repository state, the cursor location, the edit history,
/// and the patch(es) considered a correct prediction.
///
/// Besides serde, an example can be written to and read from markdown via
/// [`ExampleSpec::to_markdown`] and [`ExampleSpec::from_markdown`]. The markdown form does not
/// contain `captured_prompt_input`, `telemetry`, `human_feedback` or `rating`.
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct ExampleSpec {
    /// Example name; the markdown `#` heading, and the input of [`ExampleSpec::filename`].
    #[serde(default)]
    pub name: String,
    /// Repository URL, stored in the markdown front matter.
    pub repository_url: String,
    /// Revision of the repository, stored in the markdown front matter.
    /// NOTE(review): presumably a git commit SHA — confirm.
    pub revision: String,
    /// Free-form tags, stored in the front matter (omitted when empty).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tags: Vec<String>,
    /// Optional free-form text, written as the "Reasoning" markdown section.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
    /// Diff text written as the "Uncommitted Diff" section; the section is omitted when empty.
    #[serde(default)]
    pub uncommitted_diff: String,
    /// Path of the file containing the cursor; used as the info string of the cursor block.
    pub cursor_path: Arc<Path>,
    /// Excerpt of the cursor file with the cursor marked, either inline with
    /// [`INLINE_CURSOR_MARKER`] or by a [`CURSOR_POSITION_MARKER`] comment line below the
    /// cursor line. Decode it with [`ExampleSpec::cursor_excerpt`].
    pub cursor_position: String,
    /// Edit history as diff text; empty means "no edit history".
    pub edit_history: String,
    /// Expected patches (diff text). Each may contain a cursor marker comment line, see
    /// [`ExampleSpec::expected_patches_with_cursor_positions`].
    pub expected_patches: Vec<String>,
    /// Optional rejected patch (diff text), written as the "Rejected Patch" section.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rejected_patch: Option<String>,
    /// Data needed to format the prompt without loading the project, when it was captured.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub captured_prompt_input: Option<CapturedPromptInput>,
    /// Metadata for examples sourced from production telemetry.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub telemetry: Option<TelemetrySource>,
    /// Human feedback messages attached to this example.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub human_feedback: Vec<HumanFeedback>,
    /// Optional rating of the prediction.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rating: Option<EditPredictionRating>,
}
 42
/// A single free-form human feedback message about an example.
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct HumanFeedback {
    /// The feedback text.
    pub message: String,
}
 47
/// Metadata for examples sourced from production telemetry (rejected predictions).
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct TelemetrySource {
    /// Identifier of the prediction request.
    pub request_id: String,
    /// Identifier of the device the request came from.
    pub device_id: String,
    /// Time of the event. NOTE(review): the string format is not visible here — confirm.
    pub time: String,
    /// Why the prediction was rejected.
    pub rejection_reason: String,
    /// Whether the prediction was shown to the user.
    pub was_shown: bool,
}
 57
/// All data needed to run format_prompt without loading the project.
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct CapturedPromptInput {
    /// Content of the file containing the cursor.
    /// NOTE(review): presumably only captured for files up to `MAX_CURSOR_FILE_SIZE` — confirm.
    pub cursor_file_content: String,
    /// Cursor offset into `cursor_file_content` (presumably in bytes — confirm).
    pub cursor_offset: usize,
    /// Cursor row. NOTE(review): zero- vs one-based is not visible here.
    pub cursor_row: u32,
    /// Cursor column. NOTE(review): unit (bytes vs chars) and base are not visible here.
    pub cursor_column: u32,
    /// Row at which the prompt excerpt starts, when known (may be absent in older data).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub excerpt_start_row: Option<u32>,
    /// Captured edit events; convert with [`CapturedEvent::to_event`].
    pub events: Vec<CapturedEvent>,
    /// Captured related files; convert with [`CapturedRelatedFile::to_related_file`].
    pub related_files: Vec<CapturedRelatedFile>,
}
 70
/// A captured buffer-change event; the serializable counterpart of
/// `zeta_prompt::Event::BufferChange` (see [`CapturedEvent::to_event`]).
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct CapturedEvent {
    /// Path of the changed buffer.
    pub path: Arc<Path>,
    /// Previous path of the buffer (presumably differs from `path` after a rename — confirm).
    pub old_path: Arc<Path>,
    /// Diff text describing the change.
    pub diff: String,
    /// Flag copied into the event's `predicted` field (presumably: the change came from a
    /// prediction — confirm).
    pub predicted: bool,
    /// Flag copied into the event's `in_open_source_repo` field.
    pub in_open_source_repo: bool,
}
 79
 80impl CapturedEvent {
 81    pub fn to_event(&self) -> zeta_prompt::Event {
 82        zeta_prompt::Event::BufferChange {
 83            path: self.path.clone(),
 84            old_path: self.old_path.clone(),
 85            diff: self.diff.clone(),
 86            predicted: self.predicted,
 87            in_open_source_repo: self.in_open_source_repo,
 88        }
 89    }
 90}
 91
/// A file related to the cursor file, with the excerpts of it included in the prompt.
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct CapturedRelatedFile {
    /// Path of the related file.
    pub path: Arc<Path>,
    /// Copied into `zeta_prompt::RelatedFile::max_row`.
    /// NOTE(review): presumably the last row of the file — confirm.
    pub max_row: u32,
    /// The captured excerpts of this file.
    pub excerpts: Vec<CapturedRelatedExcerpt>,
}
 98
 99impl CapturedRelatedFile {
100    pub fn to_related_file(&self) -> zeta_prompt::RelatedFile {
101        zeta_prompt::RelatedFile {
102            path: self.path.clone(),
103            max_row: self.max_row,
104            excerpts: self
105                .excerpts
106                .iter()
107                .map(|e| zeta_prompt::RelatedExcerpt {
108                    row_range: e.row_range.clone(),
109                    text: e.text.clone().into(),
110                })
111                .collect(),
112        }
113    }
114}
115
/// A contiguous excerpt of a related file.
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct CapturedRelatedExcerpt {
    /// Rows covered by `text` (a `Range`, so presumably end-exclusive — confirm).
    pub row_range: Range<u32>,
    /// The excerpt text.
    pub text: String,
}
121
// Titles of the `##` sections in the markdown form of an `ExampleSpec`. When parsing,
// section titles are compared with these ignoring ASCII case.
const REASONING_HEADING: &str = "Reasoning";
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const REJECTED_PATCH_HEADING: &str = "Rejected Patch";
128
/// TOML front matter stored between `+++` lines at the top of an example's markdown.
///
/// The `Cow` fields let `to_markdown` serialize borrowed values without cloning them.
#[derive(Serialize, Deserialize)]
struct FrontMatter<'a> {
    repository_url: Cow<'a, str>,
    revision: Cow<'a, str>,
    // Omitted from the output when empty; defaults to empty when missing.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    tags: Vec<String>,
}
136
137impl ExampleSpec {
138    /// Generate a sanitized filename for this example.
139    pub fn filename(&self) -> String {
140        self.name
141            .chars()
142            .map(|c| match c {
143                ' ' | ':' | '~' | '^' | '?' | '*' | '[' | '\\' | '@' | '{' | '/' | '<' | '>'
144                | '|' | '"' => '-',
145                c => c,
146            })
147            .collect()
148    }
149
150    /// Format this example spec as markdown.
151    pub fn to_markdown(&self) -> String {
152        use std::fmt::Write as _;
153
154        let front_matter = FrontMatter {
155            repository_url: Cow::Borrowed(&self.repository_url),
156            revision: Cow::Borrowed(&self.revision),
157            tags: self.tags.clone(),
158        };
159        let front_matter_toml =
160            toml::to_string_pretty(&front_matter).unwrap_or_else(|_| String::new());
161
162        let mut markdown = String::new();
163
164        _ = writeln!(markdown, "+++");
165        markdown.push_str(&front_matter_toml);
166        if !markdown.ends_with('\n') {
167            markdown.push('\n');
168        }
169        _ = writeln!(markdown, "+++");
170        markdown.push('\n');
171
172        _ = writeln!(markdown, "# {}", self.name);
173        markdown.push('\n');
174
175        if let Some(reasoning) = &self.reasoning {
176            _ = writeln!(markdown, "## {}", REASONING_HEADING);
177            markdown.push('\n');
178            markdown.push_str(reasoning);
179            if !markdown.ends_with('\n') {
180                markdown.push('\n');
181            }
182            markdown.push('\n');
183        }
184
185        if !self.uncommitted_diff.is_empty() {
186            _ = writeln!(markdown, "## {}", UNCOMMITTED_DIFF_HEADING);
187            _ = writeln!(markdown);
188            _ = writeln!(markdown, "```diff");
189            markdown.push_str(&self.uncommitted_diff);
190            if !markdown.ends_with('\n') {
191                markdown.push('\n');
192            }
193            _ = writeln!(markdown, "```");
194            markdown.push('\n');
195        }
196
197        _ = writeln!(markdown, "## {}", EDIT_HISTORY_HEADING);
198        _ = writeln!(markdown);
199
200        if self.edit_history.is_empty() {
201            _ = writeln!(markdown, "(No edit history)");
202            _ = writeln!(markdown);
203        } else {
204            _ = writeln!(markdown, "```diff");
205            markdown.push_str(&self.edit_history);
206            if !markdown.ends_with('\n') {
207                markdown.push('\n');
208            }
209            _ = writeln!(markdown, "```");
210            markdown.push('\n');
211        }
212
213        _ = writeln!(markdown, "## {}", CURSOR_POSITION_HEADING);
214        _ = writeln!(markdown);
215        _ = writeln!(markdown, "```{}", self.cursor_path.to_string_lossy());
216        markdown.push_str(&self.cursor_position);
217        if !markdown.ends_with('\n') {
218            markdown.push('\n');
219        }
220        _ = writeln!(markdown, "```");
221        markdown.push('\n');
222
223        _ = writeln!(markdown, "## {}", EXPECTED_PATCH_HEADING);
224        markdown.push('\n');
225        for patch in &self.expected_patches {
226            _ = writeln!(markdown, "```diff");
227            markdown.push_str(patch);
228            if !markdown.ends_with('\n') {
229                markdown.push('\n');
230            }
231            _ = writeln!(markdown, "```");
232            markdown.push('\n');
233        }
234
235        if let Some(rejected_patch) = &self.rejected_patch {
236            _ = writeln!(markdown, "## {}", REJECTED_PATCH_HEADING);
237            markdown.push('\n');
238            _ = writeln!(markdown, "```diff");
239            markdown.push_str(rejected_patch);
240            if !markdown.ends_with('\n') {
241                markdown.push('\n');
242            }
243            _ = writeln!(markdown, "```");
244            markdown.push('\n');
245        }
246
247        markdown
248    }
249
250    /// Parse an example spec from markdown.
251    pub fn from_markdown(mut input: &str) -> anyhow::Result<Self> {
252        use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
253
254        let mut spec = ExampleSpec {
255            name: String::new(),
256            repository_url: String::new(),
257            revision: String::new(),
258            tags: Vec::new(),
259            reasoning: None,
260            uncommitted_diff: String::new(),
261            cursor_path: Path::new("").into(),
262            cursor_position: String::new(),
263            edit_history: String::new(),
264            expected_patches: Vec::new(),
265            rejected_patch: None,
266            captured_prompt_input: None,
267            telemetry: None,
268            human_feedback: Vec::new(),
269            rating: None,
270        };
271
272        if let Some(rest) = input.strip_prefix("+++\n")
273            && let Some((front_matter, rest)) = rest.split_once("+++\n")
274        {
275            if let Ok(data) = toml::from_str::<FrontMatter<'_>>(front_matter) {
276                spec.repository_url = data.repository_url.into_owned();
277                spec.revision = data.revision.into_owned();
278                spec.tags = data.tags;
279            }
280            input = rest.trim_start();
281        }
282
283        let parser = Parser::new(input);
284        let mut text = String::new();
285        let mut block_info: CowStr = "".into();
286
287        #[derive(PartialEq)]
288        enum Section {
289            Start,
290            UncommittedDiff,
291            EditHistory,
292            CursorPosition,
293            ExpectedPatch,
294            RejectedPatch,
295            Other,
296        }
297
298        let mut current_section = Section::Start;
299
300        for event in parser {
301            match event {
302                Event::Text(line) => {
303                    text.push_str(&line);
304                }
305                Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
306                    spec.name = mem::take(&mut text);
307                }
308                Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
309                    let title = mem::take(&mut text);
310                    current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
311                        Section::UncommittedDiff
312                    } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
313                        Section::EditHistory
314                    } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
315                        Section::CursorPosition
316                    } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
317                        Section::ExpectedPatch
318                    } else if title.eq_ignore_ascii_case(REJECTED_PATCH_HEADING) {
319                        Section::RejectedPatch
320                    } else {
321                        Section::Other
322                    };
323                }
324                Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
325                    mem::take(&mut text);
326                }
327                Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
328                    mem::take(&mut text);
329                }
330                Event::End(TagEnd::Heading(level)) => {
331                    anyhow::bail!("Unexpected heading level: {level}");
332                }
333                Event::Start(Tag::CodeBlock(kind)) => {
334                    match kind {
335                        CodeBlockKind::Fenced(info) => {
336                            block_info = info;
337                        }
338                        CodeBlockKind::Indented => {
339                            anyhow::bail!("Unexpected indented codeblock");
340                        }
341                    };
342                }
343                Event::Start(_) => {
344                    text.clear();
345                    block_info = "".into();
346                }
347                Event::End(TagEnd::CodeBlock) => {
348                    let block_info = block_info.trim();
349                    match current_section {
350                        Section::UncommittedDiff => {
351                            spec.uncommitted_diff = mem::take(&mut text);
352                        }
353                        Section::EditHistory => {
354                            spec.edit_history.push_str(&mem::take(&mut text));
355                        }
356                        Section::CursorPosition => {
357                            spec.cursor_path = Path::new(block_info).into();
358                            spec.cursor_position = mem::take(&mut text);
359                        }
360                        Section::ExpectedPatch => {
361                            spec.expected_patches.push(mem::take(&mut text));
362                        }
363                        Section::RejectedPatch => {
364                            spec.rejected_patch = Some(mem::take(&mut text));
365                        }
366                        Section::Start | Section::Other => {}
367                    }
368                }
369                _ => {}
370            }
371        }
372
373        if spec.cursor_path.as_ref() == Path::new("") || spec.cursor_position.is_empty() {
374            anyhow::bail!("Missing cursor position codeblock");
375        }
376
377        Ok(spec)
378    }
379
380    /// Returns the excerpt of text around the cursor, and the offset of the cursor within that
381    /// excerpt.
382    ///
383    /// The cursor's position is marked with a special comment that appears
384    /// below the cursor line, which contains the string `[CURSOR_POSITION]`,
385    /// preceded by an arrow marking the cursor's column. The arrow can be
386    /// either:
387    /// - `^` - The cursor column is at the position of the `^` character (pointing up to the cursor)
388    /// - `<` - The cursor column is at the first non-whitespace character on that line.
389    pub fn cursor_excerpt(&self) -> Result<(String, usize)> {
390        let input = &self.cursor_position;
391
392        // Check for inline cursor marker first
393        if let Some(inline_offset) = input.find(INLINE_CURSOR_MARKER) {
394            let excerpt = input[..inline_offset].to_string()
395                + &input[inline_offset + INLINE_CURSOR_MARKER.len()..];
396            return Ok((excerpt, inline_offset));
397        }
398
399        let marker_offset = input
400            .find(CURSOR_POSITION_MARKER)
401            .context("missing [CURSOR_POSITION] marker")?;
402        let marker_line_start = input[..marker_offset]
403            .rfind('\n')
404            .map(|pos| pos + 1)
405            .unwrap_or(0);
406        let marker_line_end = input[marker_line_start..]
407            .find('\n')
408            .map(|pos| marker_line_start + pos + 1)
409            .unwrap_or(input.len());
410        let marker_line = &input[marker_line_start..marker_line_end].trim_end_matches('\n');
411
412        let cursor_column = if let Some(cursor_offset) = marker_line.find('^') {
413            cursor_offset
414        } else if let Some(less_than_pos) = marker_line.find('<') {
415            marker_line
416                .find(|c: char| !c.is_whitespace())
417                .unwrap_or(less_than_pos)
418        } else {
419            anyhow::bail!(
420                "cursor position marker line must contain '^' or '<' before [CURSOR_POSITION]"
421            );
422        };
423
424        let mut excerpt = input[..marker_line_start].to_string() + &input[marker_line_end..];
425        excerpt.truncate(excerpt.trim_end_matches('\n').len());
426
427        // The cursor is on the line above the marker line.
428        let cursor_line_end = marker_line_start.saturating_sub(1);
429        let cursor_line_start = excerpt[..cursor_line_end]
430            .rfind('\n')
431            .map(|pos| pos + 1)
432            .unwrap_or(0);
433        let cursor_offset = cursor_line_start + cursor_column;
434
435        Ok((excerpt, cursor_offset))
436    }
437
438    /// Sets the cursor position excerpt from a plain excerpt and cursor byte offset.
439    ///
440    /// The `line_comment_prefix` is used to format the marker line as a comment.
441    /// If the cursor column is less than the comment prefix length, the `<` format is used.
442    /// Otherwise, the `^` format is used.
443    pub fn set_cursor_excerpt(
444        &mut self,
445        excerpt: &str,
446        cursor_offset: usize,
447        line_comment_prefix: &str,
448    ) {
449        // Find which line the cursor is on and its column
450        let cursor_line_start = excerpt[..cursor_offset]
451            .rfind('\n')
452            .map(|pos| pos + 1)
453            .unwrap_or(0);
454        let cursor_line_end = excerpt[cursor_line_start..]
455            .find('\n')
456            .map(|pos| cursor_line_start + pos + 1)
457            .unwrap_or(excerpt.len());
458        let cursor_line = &excerpt[cursor_line_start..cursor_line_end];
459        let cursor_line_indent = &cursor_line[..cursor_line.len() - cursor_line.trim_start().len()];
460        let cursor_column = cursor_offset - cursor_line_start;
461
462        // Build the marker line
463        let mut marker_line = String::new();
464        if cursor_column < line_comment_prefix.len() {
465            for _ in 0..cursor_column {
466                marker_line.push(' ');
467            }
468            marker_line.push_str(line_comment_prefix);
469            write!(marker_line, " <{}", CURSOR_POSITION_MARKER).unwrap();
470        } else {
471            if cursor_column >= cursor_line_indent.len() + line_comment_prefix.len() {
472                marker_line.push_str(cursor_line_indent);
473            }
474            marker_line.push_str(line_comment_prefix);
475            while marker_line.len() < cursor_column {
476                marker_line.push(' ');
477            }
478            write!(marker_line, "^{}", CURSOR_POSITION_MARKER).unwrap();
479        }
480
481        // Build the final cursor_position string
482        let mut result = String::with_capacity(excerpt.len() + marker_line.len() + 2);
483        result.push_str(&excerpt[..cursor_line_end]);
484        if !result.ends_with('\n') {
485            result.push('\n');
486        }
487        result.push_str(&marker_line);
488        if cursor_line_end < excerpt.len() {
489            result.push('\n');
490            result.push_str(&excerpt[cursor_line_end..]);
491        }
492
493        self.cursor_position = result;
494    }
495
496    /// Returns all of the possible expected patches for this example, each with an optional
497    /// cursor offset.
498    ///
499    /// The cursor offset is an offset within the new text (after applying the patch), relative
500    /// to the start of the hunk.
501    ///
502    /// In the serialized representation of this example, the cursor position is represented
503    /// using a comment line in the diff, beginning with `#`, and containing a `[CURSOR_POSITION]`
504    /// marker with the same format as the [`Self::cursor_excerpt`].
505    pub fn expected_patches_with_cursor_positions(&self) -> Vec<(String, Option<usize>)> {
506        self.expected_patches
507            .iter()
508            .map(|patch| {
509                let mut clean_patch = String::new();
510                let mut cursor_offset: Option<usize> = None;
511                let mut line_start_offset = 0usize;
512                let mut prev_line_start_offset = 0usize;
513
514                for line in patch.lines() {
515                    let diff_line = DiffLine::parse(line);
516
517                    match &diff_line {
518                        DiffLine::Garbage(content)
519                            if content.starts_with('#')
520                                && content.contains(CURSOR_POSITION_MARKER) =>
521                        {
522                            let caret_column = if let Some(caret_pos) = content.find('^') {
523                                caret_pos
524                            } else if let Some(_) = content.find('<') {
525                                0
526                            } else {
527                                continue;
528                            };
529                            let cursor_column = caret_column.saturating_sub('#'.len_utf8());
530                            cursor_offset = Some(prev_line_start_offset + cursor_column);
531                        }
532                        _ => {
533                            if !clean_patch.is_empty() {
534                                clean_patch.push('\n');
535                            }
536                            clean_patch.push_str(line);
537
538                            match diff_line {
539                                DiffLine::Addition(content) | DiffLine::Context(content) => {
540                                    prev_line_start_offset = line_start_offset;
541                                    line_start_offset += content.len() + 1;
542                                }
543                                _ => {}
544                            }
545                        }
546                    }
547                }
548
549                if patch.ends_with('\n') && !clean_patch.is_empty() {
550                    clean_patch.push('\n');
551                }
552
553                (clean_patch, cursor_offset)
554            })
555            .collect()
556    }
557
558    pub fn set_expected_patches_with_cursor_positions(
559        &mut self,
560        patches: Vec<(String, Option<usize>)>,
561    ) {
562        self.expected_patches = patches
563            .into_iter()
564            .map(|(patch, cursor_editable_region_offset)| {
565                let Some(cursor_offset) = cursor_editable_region_offset else {
566                    return patch;
567                };
568
569                let mut result = String::new();
570                let mut line_start_offset = 0usize;
571
572                for line in patch.lines() {
573                    if !result.is_empty() {
574                        result.push('\n');
575                    }
576                    result.push_str(line);
577
578                    match DiffLine::parse(line) {
579                        DiffLine::Addition(content) => {
580                            let line_end_offset = line_start_offset + content.len();
581
582                            if cursor_offset >= line_start_offset
583                                && cursor_offset <= line_end_offset
584                            {
585                                let cursor_column = cursor_offset - line_start_offset;
586
587                                result.push('\n');
588                                result.push('#');
589                                for _ in 0..cursor_column {
590                                    result.push(' ');
591                                }
592                                write!(result, "^{}", CURSOR_POSITION_MARKER).unwrap();
593                            }
594
595                            line_start_offset = line_end_offset + 1;
596                        }
597                        DiffLine::Context(content) => {
598                            line_start_offset += content.len() + 1;
599                        }
600                        _ => {}
601                    }
602                }
603
604                if patch.ends_with('\n') {
605                    result.push('\n');
606                }
607
608                result
609            })
610            .collect();
611    }
612}
613
#[cfg(test)]
mod tests {
    // Round-trip tests for the cursor-marker encodings used by `ExampleSpec`.
    use super::*;
    use indoc::indoc;

    // `set_cursor_excerpt` must write the expected `^` / `<` marker line, and
    // `cursor_excerpt` must decode it back to the original excerpt and offset.
    #[test]
    fn test_cursor_excerpt_with_caret() {
        let mut spec = ExampleSpec {
            name: String::new(),
            repository_url: String::new(),
            revision: String::new(),
            tags: Vec::new(),
            reasoning: None,
            uncommitted_diff: String::new(),
            cursor_path: Path::new("test.rs").into(),
            cursor_position: String::new(),
            edit_history: String::new(),
            expected_patches: Vec::new(),
            rejected_patch: None,
            captured_prompt_input: None,
            telemetry: None,
            human_feedback: Vec::new(),
            rating: None,
        };

        // Cursor before `42`
        // (the column leaves room for indentation + `//`, so the marker keeps the indent)
        let excerpt = indoc! {"
            fn main() {
                let x = 42;
                println!(\"{}\", x);
            }"
        };
        let offset = excerpt.find("42").unwrap();
        let position_string = indoc! {"
            fn main() {
                let x = 42;
                //      ^[CURSOR_POSITION]
                println!(\"{}\", x);
            }"
        }
        .to_string();

        spec.set_cursor_excerpt(excerpt, offset, "//");
        assert_eq!(spec.cursor_position, position_string);
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (excerpt.to_string(), offset)
        );

        // Cursor after `l` in `let`
        // (no room for indentation + `//`, so the prefix starts at column 0)
        let offset = excerpt.find("et x").unwrap();
        let position_string = indoc! {"
            fn main() {
                let x = 42;
            //   ^[CURSOR_POSITION]
                println!(\"{}\", x);
            }"
        }
        .to_string();

        spec.set_cursor_excerpt(excerpt, offset, "//");
        assert_eq!(spec.cursor_position, position_string);
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (excerpt.to_string(), offset)
        );

        // Cursor before `let`
        let offset = excerpt.find("let").unwrap();
        let position_string = indoc! {"
            fn main() {
                let x = 42;
            //  ^[CURSOR_POSITION]
                println!(\"{}\", x);
            }"
        }
        .to_string();

        spec.set_cursor_excerpt(excerpt, offset, "//");
        assert_eq!(spec.cursor_position, position_string);
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (excerpt.to_string(), offset)
        );

        // Cursor at beginning of the line with `let`
        // (column 0 is inside the `//` prefix width, so the `<` form is used)
        let offset = excerpt.find("    let").unwrap();
        let position_string = indoc! {"
            fn main() {
                let x = 42;
            // <[CURSOR_POSITION]
                println!(\"{}\", x);
            }"
        }
        .to_string();

        spec.set_cursor_excerpt(excerpt, offset, "//");
        assert_eq!(spec.cursor_position, position_string);
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (excerpt.to_string(), offset)
        );

        // Cursor at end of line, after the semicolon
        let offset = excerpt.find(';').unwrap() + 1;
        let position_string = indoc! {"
            fn main() {
                let x = 42;
                //         ^[CURSOR_POSITION]
                println!(\"{}\", x);
            }"
        }
        .to_string();

        spec.set_cursor_excerpt(excerpt, offset, "//");
        assert_eq!(spec.cursor_position, position_string);
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (excerpt.to_string(), offset)
        );

        // Caret at end of file (no trailing newline)
        // (the newline added before the marker line must not survive decoding)
        let excerpt = indoc! {"
            fn main() {
                let x = 42;"
        };
        let offset = excerpt.find(';').unwrap() + 1;
        let position_string = indoc! {"
            fn main() {
                let x = 42;
                //         ^[CURSOR_POSITION]"
        }
        .to_string();

        spec.set_cursor_excerpt(excerpt, offset, "//");
        assert_eq!(spec.cursor_position, position_string);
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (excerpt.to_string(), offset)
        );
    }

    // `cursor_excerpt` strips the inline `<|user_cursor|>` marker and reports its offset.
    #[test]
    fn test_cursor_excerpt_with_inline_marker() {
        let mut spec = ExampleSpec {
            name: String::new(),
            repository_url: String::new(),
            revision: String::new(),
            tags: Vec::new(),
            reasoning: None,
            uncommitted_diff: String::new(),
            cursor_path: Path::new("test.rs").into(),
            cursor_position: String::new(),
            edit_history: String::new(),
            expected_patches: Vec::new(),
            rejected_patch: None,
            captured_prompt_input: None,
            telemetry: None,
            human_feedback: Vec::new(),
            rating: None,
        };

        // Cursor before `42` using inline marker
        spec.cursor_position = indoc! {"
            fn main() {
                let x = <|user_cursor|>42;
                println!(\"{}\", x);
            }"
        }
        .to_string();

        let expected_excerpt = indoc! {"
            fn main() {
                let x = 42;
                println!(\"{}\", x);
            }"
        };
        let expected_offset = expected_excerpt.find("42").unwrap();

        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (expected_excerpt.to_string(), expected_offset)
        );

        // Cursor at beginning of line
        spec.cursor_position = indoc! {"
            fn main() {
            <|user_cursor|>    let x = 42;
            }"
        }
        .to_string();

        let expected_excerpt = indoc! {"
            fn main() {
                let x = 42;
            }"
        };
        let expected_offset = expected_excerpt.find("    let").unwrap();

        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (expected_excerpt.to_string(), expected_offset)
        );

        // Cursor at end of file
        spec.cursor_position = "fn main() {}<|user_cursor|>".to_string();
        let expected_excerpt = "fn main() {}";
        let expected_offset = expected_excerpt.len();

        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (expected_excerpt.to_string(), expected_offset)
        );
    }

    // Encoding a cursor offset into an expected patch and decoding it again must round-trip,
    // and a patch without a cursor must pass through unchanged.
    #[test]
    fn test_expected_patches_with_cursor_positions() {
        let mut spec = ExampleSpec {
            name: String::new(),
            repository_url: String::new(),
            revision: String::new(),
            tags: Vec::new(),
            reasoning: None,
            uncommitted_diff: String::new(),
            cursor_path: Path::new("test.rs").into(),
            cursor_position: String::new(),
            edit_history: String::new(),
            expected_patches: Vec::new(),
            rejected_patch: None,
            captured_prompt_input: None,
            telemetry: None,
            human_feedback: Vec::new(),
            rating: None,
        };

        // The new text after applying `clean_patch`; the cursor sits before `);`.
        let new_content = indoc! {r#"
            // prints a greeting
            fn main() {
                println!("hello, {}", );
                let x = 42;
            }
        "#};
        let cursor_offset = new_content.find(");").unwrap();

        let clean_patch = indoc! {r#"
            --- a/test.rs
            +++ b/test.rs
            @@ -1,3 +1,4 @@
            +// prints a greeting
             fn main() {
            -    println!("hi");
            +    println!("hello, {}", );
                 let x = 42;
             }
        "#}
        .to_string();

        // The marker follows the added `println!` line; `#` takes the place of the `+`, so
        // the `^` lines up with `);` in that line.
        let encoded_patch = indoc! {r#"
            --- a/test.rs
            +++ b/test.rs
            @@ -1,3 +1,4 @@
            +// prints a greeting
             fn main() {
            -    println!("hi");
            +    println!("hello, {}", );
            #                          ^[CURSOR_POSITION]
                 let x = 42;
             }
        "#}
        .to_string();

        spec.set_expected_patches_with_cursor_positions(vec![(
            clean_patch.clone(),
            Some(cursor_offset),
        )]);
        assert_eq!(spec.expected_patches, vec![encoded_patch]);

        let results = spec.expected_patches_with_cursor_positions();
        assert_eq!(results, vec![(clean_patch.clone(), Some(cursor_offset))]);

        spec.set_expected_patches_with_cursor_positions(vec![(clean_patch.clone(), None)]);
        assert_eq!(spec.expected_patches, vec![clean_patch.clone()]);

        let results = spec.expected_patches_with_cursor_positions();
        assert_eq!(results, vec![(clean_patch, None)]);
    }
}