example_spec.rs

  1use anyhow::{Context as _, Result};
  2use serde::{Deserialize, Serialize};
  3use std::{borrow::Cow, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc};
  4
/// Marker text placed on a separate line directly below the cursor line of a
/// cursor-position excerpt. On that line it is accompanied by a `^` or `<`
/// arrow that encodes the cursor column (see `ExampleSpec::cursor_excerpt`).
pub const CURSOR_POSITION_MARKER: &str = "[CURSOR_POSITION]";
/// Marker that may be written inline in a cursor-position excerpt, exactly
/// where the cursor is. `ExampleSpec::cursor_excerpt` looks for it before
/// falling back to the line-based `[CURSOR_POSITION]` marker.
pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>";

/// Maximum cursor file size to capture (64KB).
/// Files larger than this will not have their content captured,
/// falling back to git-based loading.
pub const MAX_CURSOR_FILE_SIZE: usize = 64 * 1024;
 12
/// Description of a single example: where the code lives, the editing context
/// around a cursor, and the patches expected (or rejected) at that point.
///
/// Besides its serde representation, a spec can be written as a markdown
/// document with `to_markdown` and read back with `from_markdown`.
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct ExampleSpec {
    /// Title of the example; emitted as the level-1 markdown heading and used
    /// as the basis for `filename`.
    #[serde(default)]
    pub name: String,
    /// Repository the example was taken from (stored in the markdown front matter).
    pub repository_url: String,
    /// Revision of the repository (stored in the markdown front matter).
    /// NOTE(review): presumably a git commit hash or ref — confirm with callers.
    pub revision: String,
    /// Free-form tags (stored in the markdown front matter, omitted when empty).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tags: Vec<String>,
    /// Optional prose written under the "Reasoning" markdown heading.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
    /// Diff written under the "Uncommitted Diff" heading; empty when there is none.
    #[serde(default)]
    pub uncommitted_diff: String,
    /// Path of the file containing the cursor; used as the info string of the
    /// "Cursor Position" code block.
    pub cursor_path: Arc<Path>,
    /// Excerpt of the cursor file with the cursor marked either by a
    /// `[CURSOR_POSITION]` comment line or by an inline `<|user_cursor|>` marker.
    pub cursor_position: String,
    /// Diff written under the "Edit History" heading; may be empty.
    pub edit_history: String,
    /// Expected patches; each one is written as its own `diff` code block.
    pub expected_patches: Vec<String>,
    /// Optional patch written under the "Rejected Patch" heading.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rejected_patch: Option<String>,
    /// Optional pre-captured prompt input. It is part of the serde
    /// representation only; `to_markdown` does not write it.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub captured_prompt_input: Option<CapturedPromptInput>,
}
 34
/// All data needed to run format_prompt without loading the project.
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct CapturedPromptInput {
    /// Full text of the file that contains the cursor.
    pub cursor_file_content: String,
    /// Position of the cursor within `cursor_file_content`.
    /// NOTE(review): presumably a byte offset — confirm with the producer.
    pub cursor_offset: usize,
    /// Row of the cursor. NOTE(review): zero- vs one-based is not visible here.
    pub cursor_row: u32,
    /// Column of the cursor. NOTE(review): unit (bytes/chars) is not visible here.
    pub cursor_column: u32,
    /// Captured edit events; each converts to a `zeta_prompt::Event` via `to_event`.
    pub events: Vec<CapturedEvent>,
    /// Captured related files; each converts to a `zeta_prompt::RelatedFile`
    /// via `to_related_file`.
    pub related_files: Vec<CapturedRelatedFile>,
}
 45
/// Serializable snapshot of a buffer-change event; mirrors the fields of
/// `zeta_prompt::Event::BufferChange` (see `to_event`).
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct CapturedEvent {
    /// Path of the changed file.
    pub path: Arc<Path>,
    /// Previous path of the file.
    /// NOTE(review): presumably equal to `path` unless the file was renamed — confirm.
    pub old_path: Arc<Path>,
    /// Textual diff describing the change.
    pub diff: String,
    /// Forwarded unchanged to `BufferChange::predicted`.
    /// NOTE(review): presumably marks edits that came from a prediction — confirm.
    pub predicted: bool,
    /// Forwarded unchanged to `BufferChange::in_open_source_repo`.
    pub in_open_source_repo: bool,
}
 54
 55impl CapturedEvent {
 56    pub fn to_event(&self) -> zeta_prompt::Event {
 57        zeta_prompt::Event::BufferChange {
 58            path: self.path.clone(),
 59            old_path: self.old_path.clone(),
 60            diff: self.diff.clone(),
 61            predicted: self.predicted,
 62            in_open_source_repo: self.in_open_source_repo,
 63        }
 64    }
 65}
 66
/// Serializable snapshot of a related file and the excerpts taken from it;
/// mirrors `zeta_prompt::RelatedFile` (see `to_related_file`).
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct CapturedRelatedFile {
    /// Path of the related file.
    pub path: Arc<Path>,
    /// Forwarded unchanged to `RelatedFile::max_row`.
    /// NOTE(review): presumably the last row of the file — confirm.
    pub max_row: u32,
    /// Excerpts captured from this file.
    pub excerpts: Vec<CapturedRelatedExcerpt>,
}
 73
 74impl CapturedRelatedFile {
 75    pub fn to_related_file(&self) -> zeta_prompt::RelatedFile {
 76        zeta_prompt::RelatedFile {
 77            path: self.path.clone(),
 78            max_row: self.max_row,
 79            excerpts: self
 80                .excerpts
 81                .iter()
 82                .map(|e| zeta_prompt::RelatedExcerpt {
 83                    row_range: e.row_range.clone(),
 84                    text: e.text.clone().into(),
 85                })
 86                .collect(),
 87        }
 88    }
 89}
 90
/// Serializable snapshot of one excerpt of a related file; mirrors
/// `zeta_prompt::RelatedExcerpt`.
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct CapturedRelatedExcerpt {
    /// Rows of the file covered by this excerpt.
    /// NOTE(review): zero- vs one-based rows is not visible here — confirm.
    pub row_range: Range<u32>,
    /// Text of the excerpt.
    pub text: String,
}
 96
// Titles of the level-2 (`##`) sections of the markdown form of an
// `ExampleSpec`. `to_markdown` writes them verbatim; `from_markdown` matches
// section titles against them ignoring ASCII case.
const REASONING_HEADING: &str = "Reasoning";
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const REJECTED_PATCH_HEADING: &str = "Rejected Patch";
103
/// TOML front matter stored between the `+++` delimiters at the top of the
/// markdown form of an `ExampleSpec`.
///
/// The string fields are `Cow`s so that serialization can borrow from the
/// spec instead of cloning.
#[derive(Serialize, Deserialize)]
struct FrontMatter<'a> {
    /// Mirrors `ExampleSpec::repository_url`.
    repository_url: Cow<'a, str>,
    /// Mirrors `ExampleSpec::revision`.
    revision: Cow<'a, str>,
    /// Mirrors `ExampleSpec::tags`; omitted from the output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    tags: Vec<String>,
}
111
112impl ExampleSpec {
113    /// Generate a sanitized filename for this example.
114    pub fn filename(&self) -> String {
115        self.name
116            .chars()
117            .map(|c| match c {
118                ' ' | ':' | '~' | '^' | '?' | '*' | '[' | '\\' | '@' | '{' | '/' | '<' | '>'
119                | '|' | '"' => '-',
120                c => c,
121            })
122            .collect()
123    }
124
125    /// Format this example spec as markdown.
126    pub fn to_markdown(&self) -> String {
127        use std::fmt::Write as _;
128
129        let front_matter = FrontMatter {
130            repository_url: Cow::Borrowed(&self.repository_url),
131            revision: Cow::Borrowed(&self.revision),
132            tags: self.tags.clone(),
133        };
134        let front_matter_toml =
135            toml::to_string_pretty(&front_matter).unwrap_or_else(|_| String::new());
136
137        let mut markdown = String::new();
138
139        _ = writeln!(markdown, "+++");
140        markdown.push_str(&front_matter_toml);
141        if !markdown.ends_with('\n') {
142            markdown.push('\n');
143        }
144        _ = writeln!(markdown, "+++");
145        markdown.push('\n');
146
147        _ = writeln!(markdown, "# {}", self.name);
148        markdown.push('\n');
149
150        if let Some(reasoning) = &self.reasoning {
151            _ = writeln!(markdown, "## {}", REASONING_HEADING);
152            markdown.push('\n');
153            markdown.push_str(reasoning);
154            if !markdown.ends_with('\n') {
155                markdown.push('\n');
156            }
157            markdown.push('\n');
158        }
159
160        if !self.uncommitted_diff.is_empty() {
161            _ = writeln!(markdown, "## {}", UNCOMMITTED_DIFF_HEADING);
162            _ = writeln!(markdown);
163            _ = writeln!(markdown, "```diff");
164            markdown.push_str(&self.uncommitted_diff);
165            if !markdown.ends_with('\n') {
166                markdown.push('\n');
167            }
168            _ = writeln!(markdown, "```");
169            markdown.push('\n');
170        }
171
172        _ = writeln!(markdown, "## {}", EDIT_HISTORY_HEADING);
173        _ = writeln!(markdown);
174
175        if self.edit_history.is_empty() {
176            _ = writeln!(markdown, "(No edit history)");
177            _ = writeln!(markdown);
178        } else {
179            _ = writeln!(markdown, "```diff");
180            markdown.push_str(&self.edit_history);
181            if !markdown.ends_with('\n') {
182                markdown.push('\n');
183            }
184            _ = writeln!(markdown, "```");
185            markdown.push('\n');
186        }
187
188        _ = writeln!(markdown, "## {}", CURSOR_POSITION_HEADING);
189        _ = writeln!(markdown);
190        _ = writeln!(markdown, "```{}", self.cursor_path.to_string_lossy());
191        markdown.push_str(&self.cursor_position);
192        if !markdown.ends_with('\n') {
193            markdown.push('\n');
194        }
195        _ = writeln!(markdown, "```");
196        markdown.push('\n');
197
198        _ = writeln!(markdown, "## {}", EXPECTED_PATCH_HEADING);
199        markdown.push('\n');
200        for patch in &self.expected_patches {
201            _ = writeln!(markdown, "```diff");
202            markdown.push_str(patch);
203            if !markdown.ends_with('\n') {
204                markdown.push('\n');
205            }
206            _ = writeln!(markdown, "```");
207            markdown.push('\n');
208        }
209
210        if let Some(rejected_patch) = &self.rejected_patch {
211            _ = writeln!(markdown, "## {}", REJECTED_PATCH_HEADING);
212            markdown.push('\n');
213            _ = writeln!(markdown, "```diff");
214            markdown.push_str(rejected_patch);
215            if !markdown.ends_with('\n') {
216                markdown.push('\n');
217            }
218            _ = writeln!(markdown, "```");
219            markdown.push('\n');
220        }
221
222        markdown
223    }
224
225    /// Parse an example spec from markdown.
226    pub fn from_markdown(mut input: &str) -> anyhow::Result<Self> {
227        use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
228
229        let mut spec = ExampleSpec {
230            name: String::new(),
231            repository_url: String::new(),
232            revision: String::new(),
233            tags: Vec::new(),
234            reasoning: None,
235            uncommitted_diff: String::new(),
236            cursor_path: Path::new("").into(),
237            cursor_position: String::new(),
238            edit_history: String::new(),
239            expected_patches: Vec::new(),
240            rejected_patch: None,
241            captured_prompt_input: None,
242        };
243
244        if let Some(rest) = input.strip_prefix("+++\n")
245            && let Some((front_matter, rest)) = rest.split_once("+++\n")
246        {
247            if let Ok(data) = toml::from_str::<FrontMatter<'_>>(front_matter) {
248                spec.repository_url = data.repository_url.into_owned();
249                spec.revision = data.revision.into_owned();
250                spec.tags = data.tags;
251            }
252            input = rest.trim_start();
253        }
254
255        let parser = Parser::new(input);
256        let mut text = String::new();
257        let mut block_info: CowStr = "".into();
258
259        #[derive(PartialEq)]
260        enum Section {
261            Start,
262            UncommittedDiff,
263            EditHistory,
264            CursorPosition,
265            ExpectedPatch,
266            RejectedPatch,
267            Other,
268        }
269
270        let mut current_section = Section::Start;
271
272        for event in parser {
273            match event {
274                Event::Text(line) => {
275                    text.push_str(&line);
276                }
277                Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
278                    spec.name = mem::take(&mut text);
279                }
280                Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
281                    let title = mem::take(&mut text);
282                    current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
283                        Section::UncommittedDiff
284                    } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
285                        Section::EditHistory
286                    } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
287                        Section::CursorPosition
288                    } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
289                        Section::ExpectedPatch
290                    } else if title.eq_ignore_ascii_case(REJECTED_PATCH_HEADING) {
291                        Section::RejectedPatch
292                    } else {
293                        Section::Other
294                    };
295                }
296                Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
297                    mem::take(&mut text);
298                }
299                Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
300                    mem::take(&mut text);
301                }
302                Event::End(TagEnd::Heading(level)) => {
303                    anyhow::bail!("Unexpected heading level: {level}");
304                }
305                Event::Start(Tag::CodeBlock(kind)) => {
306                    match kind {
307                        CodeBlockKind::Fenced(info) => {
308                            block_info = info;
309                        }
310                        CodeBlockKind::Indented => {
311                            anyhow::bail!("Unexpected indented codeblock");
312                        }
313                    };
314                }
315                Event::Start(_) => {
316                    text.clear();
317                    block_info = "".into();
318                }
319                Event::End(TagEnd::CodeBlock) => {
320                    let block_info = block_info.trim();
321                    match current_section {
322                        Section::UncommittedDiff => {
323                            spec.uncommitted_diff = mem::take(&mut text);
324                        }
325                        Section::EditHistory => {
326                            spec.edit_history.push_str(&mem::take(&mut text));
327                        }
328                        Section::CursorPosition => {
329                            spec.cursor_path = Path::new(block_info).into();
330                            spec.cursor_position = mem::take(&mut text);
331                        }
332                        Section::ExpectedPatch => {
333                            spec.expected_patches.push(mem::take(&mut text));
334                        }
335                        Section::RejectedPatch => {
336                            spec.rejected_patch = Some(mem::take(&mut text));
337                        }
338                        Section::Start | Section::Other => {}
339                    }
340                }
341                _ => {}
342            }
343        }
344
345        if spec.cursor_path.as_ref() == Path::new("") || spec.cursor_position.is_empty() {
346            anyhow::bail!("Missing cursor position codeblock");
347        }
348
349        Ok(spec)
350    }
351
352    /// Returns the excerpt of text around the cursor, and the offset of the cursor within that
353    /// excerpt.
354    ///
355    /// The cursor's position is marked with a special comment that appears
356    /// below the cursor line, which contains the string `[CURSOR_POSITION]`,
357    /// preceded by an arrow marking the cursor's column. The arrow can be
358    /// either:
359    /// - `^` - The cursor column is at the position of the `^` character (pointing up to the cursor)
360    /// - `<` - The cursor column is at the first non-whitespace character on that line.
361    pub fn cursor_excerpt(&self) -> Result<(String, usize)> {
362        let input = &self.cursor_position;
363
364        // Check for inline cursor marker first
365        if let Some(inline_offset) = input.find(INLINE_CURSOR_MARKER) {
366            let excerpt = input[..inline_offset].to_string()
367                + &input[inline_offset + INLINE_CURSOR_MARKER.len()..];
368            return Ok((excerpt, inline_offset));
369        }
370
371        let marker_offset = input
372            .find(CURSOR_POSITION_MARKER)
373            .context("missing [CURSOR_POSITION] marker")?;
374        let marker_line_start = input[..marker_offset]
375            .rfind('\n')
376            .map(|pos| pos + 1)
377            .unwrap_or(0);
378        let marker_line_end = input[marker_line_start..]
379            .find('\n')
380            .map(|pos| marker_line_start + pos + 1)
381            .unwrap_or(input.len());
382        let marker_line = &input[marker_line_start..marker_line_end].trim_end_matches('\n');
383
384        let cursor_column = if let Some(cursor_offset) = marker_line.find('^') {
385            cursor_offset
386        } else if let Some(less_than_pos) = marker_line.find('<') {
387            marker_line
388                .find(|c: char| !c.is_whitespace())
389                .unwrap_or(less_than_pos)
390        } else {
391            anyhow::bail!(
392                "cursor position marker line must contain '^' or '<' before [CURSOR_POSITION]"
393            );
394        };
395
396        let mut excerpt = input[..marker_line_start].to_string() + &input[marker_line_end..];
397        excerpt.truncate(excerpt.trim_end_matches('\n').len());
398
399        // The cursor is on the line above the marker line.
400        let cursor_line_end = marker_line_start.saturating_sub(1);
401        let cursor_line_start = excerpt[..cursor_line_end]
402            .rfind('\n')
403            .map(|pos| pos + 1)
404            .unwrap_or(0);
405        let cursor_offset = cursor_line_start + cursor_column;
406
407        Ok((excerpt, cursor_offset))
408    }
409
410    /// Sets the cursor position excerpt from a plain excerpt and cursor byte offset.
411    ///
412    /// The `line_comment_prefix` is used to format the marker line as a comment.
413    /// If the cursor column is less than the comment prefix length, the `<` format is used.
414    /// Otherwise, the `^` format is used.
415    pub fn set_cursor_excerpt(
416        &mut self,
417        excerpt: &str,
418        cursor_offset: usize,
419        line_comment_prefix: &str,
420    ) {
421        // Find which line the cursor is on and its column
422        let cursor_line_start = excerpt[..cursor_offset]
423            .rfind('\n')
424            .map(|pos| pos + 1)
425            .unwrap_or(0);
426        let cursor_line_end = excerpt[cursor_line_start..]
427            .find('\n')
428            .map(|pos| cursor_line_start + pos + 1)
429            .unwrap_or(excerpt.len());
430        let cursor_line = &excerpt[cursor_line_start..cursor_line_end];
431        let cursor_line_indent = &cursor_line[..cursor_line.len() - cursor_line.trim_start().len()];
432        let cursor_column = cursor_offset - cursor_line_start;
433
434        // Build the marker line
435        let mut marker_line = String::new();
436        if cursor_column < line_comment_prefix.len() {
437            for _ in 0..cursor_column {
438                marker_line.push(' ');
439            }
440            marker_line.push_str(line_comment_prefix);
441            write!(marker_line, " <{}", CURSOR_POSITION_MARKER).unwrap();
442        } else {
443            if cursor_column >= cursor_line_indent.len() + line_comment_prefix.len() {
444                marker_line.push_str(cursor_line_indent);
445            }
446            marker_line.push_str(line_comment_prefix);
447            while marker_line.len() < cursor_column {
448                marker_line.push(' ');
449            }
450            write!(marker_line, "^{}", CURSOR_POSITION_MARKER).unwrap();
451        }
452
453        // Build the final cursor_position string
454        let mut result = String::with_capacity(excerpt.len() + marker_line.len() + 2);
455        result.push_str(&excerpt[..cursor_line_end]);
456        if !result.ends_with('\n') {
457            result.push('\n');
458        }
459        result.push_str(&marker_line);
460        if cursor_line_end < excerpt.len() {
461            result.push('\n');
462            result.push_str(&excerpt[cursor_line_end..]);
463        }
464
465        self.cursor_position = result;
466    }
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;

    /// Builds a spec in which everything except the cursor path is empty.
    fn blank_spec() -> ExampleSpec {
        ExampleSpec {
            name: String::new(),
            repository_url: String::new(),
            revision: String::new(),
            tags: Vec::new(),
            reasoning: None,
            uncommitted_diff: String::new(),
            cursor_path: Path::new("test.rs").into(),
            cursor_position: String::new(),
            edit_history: String::new(),
            expected_patches: Vec::new(),
            rejected_patch: None,
            captured_prompt_input: None,
        }
    }

    /// Checks that `set_cursor_excerpt` renders `expected_position` for a `//`
    /// comment prefix, and that `cursor_excerpt` recovers the original input.
    #[track_caller]
    fn assert_marker_round_trip(
        spec: &mut ExampleSpec,
        excerpt: &str,
        offset: usize,
        expected_position: &str,
    ) {
        spec.set_cursor_excerpt(excerpt, offset, "//");
        assert_eq!(spec.cursor_position, expected_position.to_string());
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (excerpt.to_string(), offset)
        );
    }

    #[test]
    fn test_cursor_excerpt_with_caret() {
        let mut spec = blank_spec();

        let excerpt = indoc! {"
            fn main() {
                let x = 42;
                println!(\"{}\", x);
            }"
        };

        // Cursor before `42`
        assert_marker_round_trip(
            &mut spec,
            excerpt,
            excerpt.find("42").unwrap(),
            indoc! {"
                fn main() {
                    let x = 42;
                    //      ^[CURSOR_POSITION]
                    println!(\"{}\", x);
                }"
            },
        );

        // Cursor after `l` in `let`
        assert_marker_round_trip(
            &mut spec,
            excerpt,
            excerpt.find("et x").unwrap(),
            indoc! {"
                fn main() {
                    let x = 42;
                //   ^[CURSOR_POSITION]
                    println!(\"{}\", x);
                }"
            },
        );

        // Cursor before `let`
        assert_marker_round_trip(
            &mut spec,
            excerpt,
            excerpt.find("let").unwrap(),
            indoc! {"
                fn main() {
                    let x = 42;
                //  ^[CURSOR_POSITION]
                    println!(\"{}\", x);
                }"
            },
        );

        // Cursor at beginning of the line with `let`
        assert_marker_round_trip(
            &mut spec,
            excerpt,
            excerpt.find("    let").unwrap(),
            indoc! {"
                fn main() {
                    let x = 42;
                // <[CURSOR_POSITION]
                    println!(\"{}\", x);
                }"
            },
        );

        // Cursor at end of line, after the semicolon
        assert_marker_round_trip(
            &mut spec,
            excerpt,
            excerpt.find(';').unwrap() + 1,
            indoc! {"
                fn main() {
                    let x = 42;
                    //         ^[CURSOR_POSITION]
                    println!(\"{}\", x);
                }"
            },
        );

        // Caret at end of file (no trailing newline)
        let excerpt = indoc! {"
            fn main() {
                let x = 42;"
        };
        assert_marker_round_trip(
            &mut spec,
            excerpt,
            excerpt.find(';').unwrap() + 1,
            indoc! {"
                fn main() {
                    let x = 42;
                    //         ^[CURSOR_POSITION]"
            },
        );
    }

    #[test]
    fn test_cursor_excerpt_with_inline_marker() {
        let mut spec = blank_spec();

        // Cursor before `42` using inline marker
        spec.cursor_position = indoc! {"
            fn main() {
                let x = <|user_cursor|>42;
                println!(\"{}\", x);
            }"
        }
        .to_string();
        let expected_excerpt = indoc! {"
            fn main() {
                let x = 42;
                println!(\"{}\", x);
            }"
        };
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (
                expected_excerpt.to_string(),
                expected_excerpt.find("42").unwrap()
            )
        );

        // Cursor at beginning of line
        spec.cursor_position = indoc! {"
            fn main() {
            <|user_cursor|>    let x = 42;
            }"
        }
        .to_string();
        let expected_excerpt = indoc! {"
            fn main() {
                let x = 42;
            }"
        };
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (
                expected_excerpt.to_string(),
                expected_excerpt.find("    let").unwrap()
            )
        );

        // Cursor at end of file
        spec.cursor_position = "fn main() {}<|user_cursor|>".to_string();
        let expected_excerpt = "fn main() {}";
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (expected_excerpt.to_string(), expected_excerpt.len())
        );
    }
}