1use anyhow::{Context as _, Result};
2use serde::{Deserialize, Serialize};
3use std::{borrow::Cow, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc};
4
/// Text that identifies the marker line inside [`ExampleSpec::cursor_position`].
///
/// The marker line sits directly below the cursor line and carries a `^` or `<` arrow
/// that encodes the cursor column (see [`ExampleSpec::cursor_excerpt`]).
pub const CURSOR_POSITION_MARKER: &str = "[CURSOR_POSITION]";
/// Alternative marker written inline at the exact cursor position.
///
/// [`ExampleSpec::cursor_excerpt`] looks for this marker first and only falls back to
/// [`CURSOR_POSITION_MARKER`] when it is absent.
pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>";

/// Maximum cursor file size to capture (64KB, i.e. `64 * 1024`).
/// Files larger than this will not have their content captured,
/// falling back to git-based loading.
pub const MAX_CURSOR_FILE_SIZE: usize = 64 * 1024;
12
/// A single example, serializable through serde and convertible to and from a markdown
/// document (see [`ExampleSpec::to_markdown`] and [`ExampleSpec::from_markdown`]).
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct ExampleSpec {
    /// Title of the example; rendered as the H1 heading and used by
    /// [`ExampleSpec::filename`].
    #[serde(default)]
    pub name: String,
    /// Repository URL; stored in the TOML front matter of the markdown form.
    pub repository_url: String,
    /// Revision identifier; stored in the TOML front matter of the markdown form.
    pub revision: String,
    /// Free-form tags; stored in the front matter and omitted from output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tags: Vec<String>,
    /// Optional free-form text rendered under the `Reasoning` heading.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
    /// Diff rendered under the `Uncommitted Diff` heading; the section is skipped when empty.
    #[serde(default)]
    pub uncommitted_diff: String,
    /// Path of the file containing the cursor; used as the info string of the cursor
    /// position code fence.
    pub cursor_path: Arc<Path>,
    /// Excerpt around the cursor including a cursor marker; decode it with
    /// [`ExampleSpec::cursor_excerpt`].
    pub cursor_position: String,
    /// Diff rendered under the `Edit History` heading.
    pub edit_history: String,
    /// Patches rendered under the `Expected Patch` heading, one fenced block each.
    pub expected_patches: Vec<String>,
    /// Optional patch rendered under the `Rejected Patch` heading.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rejected_patch: Option<String>,
    /// Optional captured prompt input. It is only carried by the serde representation:
    /// the markdown form neither writes nor reads it.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub captured_prompt_input: Option<CapturedPromptInput>,
}
34
/// All data needed to run format_prompt without loading the project.
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct CapturedPromptInput {
    /// Content of the file containing the cursor.
    pub cursor_file_content: String,
    /// Cursor offset — presumably a byte offset into `cursor_file_content`; confirm
    /// against the code that captures it.
    pub cursor_offset: usize,
    /// Cursor row (NOTE(review): whether this is zero-based is not visible here).
    pub cursor_row: u32,
    /// Cursor column (NOTE(review): unit and base are not visible here).
    pub cursor_column: u32,
    /// Captured events; each converts to a `zeta_prompt::Event` via
    /// [`CapturedEvent::to_event`].
    pub events: Vec<CapturedEvent>,
    /// Captured related files; each converts to a `zeta_prompt::RelatedFile` via
    /// [`CapturedRelatedFile::to_related_file`].
    pub related_files: Vec<CapturedRelatedFile>,
}
45
/// Serializable snapshot of a buffer change.
///
/// The fields mirror those of `zeta_prompt::Event::BufferChange`, which
/// [`CapturedEvent::to_event`] builds from them one to one.
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct CapturedEvent {
    /// Path of the changed buffer.
    pub path: Arc<Path>,
    /// Previous path of the buffer (presumably equal to `path` unless renamed — confirm).
    pub old_path: Arc<Path>,
    /// The change itself, as diff text.
    pub diff: String,
    /// Forwarded unchanged to `BufferChange::predicted`.
    pub predicted: bool,
    /// Forwarded unchanged to `BufferChange::in_open_source_repo`.
    pub in_open_source_repo: bool,
}
54
55impl CapturedEvent {
56 pub fn to_event(&self) -> zeta_prompt::Event {
57 zeta_prompt::Event::BufferChange {
58 path: self.path.clone(),
59 old_path: self.old_path.clone(),
60 diff: self.diff.clone(),
61 predicted: self.predicted,
62 in_open_source_repo: self.in_open_source_repo,
63 }
64 }
65}
66
/// Serializable snapshot of a related file; converts to `zeta_prompt::RelatedFile` via
/// [`CapturedRelatedFile::to_related_file`].
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct CapturedRelatedFile {
    /// Path of the related file.
    pub path: Arc<Path>,
    /// Forwarded unchanged to `RelatedFile::max_row` (presumably the file's last row —
    /// confirm against `zeta_prompt`).
    pub max_row: u32,
    /// Excerpts captured from the file.
    pub excerpts: Vec<CapturedRelatedExcerpt>,
}
73
74impl CapturedRelatedFile {
75 pub fn to_related_file(&self) -> zeta_prompt::RelatedFile {
76 zeta_prompt::RelatedFile {
77 path: self.path.clone(),
78 max_row: self.max_row,
79 excerpts: self
80 .excerpts
81 .iter()
82 .map(|e| zeta_prompt::RelatedExcerpt {
83 row_range: e.row_range.clone(),
84 text: e.text.clone().into(),
85 })
86 .collect(),
87 }
88 }
89}
90
/// Serializable snapshot of one excerpt of a related file.
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct CapturedRelatedExcerpt {
    /// Rows covered by the excerpt (NOTE(review): row base is not visible here).
    pub row_range: Range<u32>,
    /// Text of the excerpt.
    pub text: String,
}
96
// Titles of the H2 sections of the markdown representation. `ExampleSpec::to_markdown`
// writes them as `## <title>`; `ExampleSpec::from_markdown` compares section titles
// against them ASCII-case-insensitively.
const REASONING_HEADING: &str = "Reasoning";
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const REJECTED_PATCH_HEADING: &str = "Rejected Patch";
103
/// TOML front matter of the markdown representation, placed between two `+++` lines at
/// the top of the document.
#[derive(Serialize, Deserialize)]
struct FrontMatter<'a> {
    /// Mirrors `ExampleSpec::repository_url`.
    repository_url: Cow<'a, str>,
    /// Mirrors `ExampleSpec::revision`.
    revision: Cow<'a, str>,
    /// Mirrors `ExampleSpec::tags`; omitted from the output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    tags: Vec<String>,
}
111
112impl ExampleSpec {
113 /// Generate a sanitized filename for this example.
114 pub fn filename(&self) -> String {
115 self.name
116 .chars()
117 .map(|c| match c {
118 ' ' | ':' | '~' | '^' | '?' | '*' | '[' | '\\' | '@' | '{' | '/' | '<' | '>'
119 | '|' | '"' => '-',
120 c => c,
121 })
122 .collect()
123 }
124
125 /// Format this example spec as markdown.
126 pub fn to_markdown(&self) -> String {
127 use std::fmt::Write as _;
128
129 let front_matter = FrontMatter {
130 repository_url: Cow::Borrowed(&self.repository_url),
131 revision: Cow::Borrowed(&self.revision),
132 tags: self.tags.clone(),
133 };
134 let front_matter_toml =
135 toml::to_string_pretty(&front_matter).unwrap_or_else(|_| String::new());
136
137 let mut markdown = String::new();
138
139 _ = writeln!(markdown, "+++");
140 markdown.push_str(&front_matter_toml);
141 if !markdown.ends_with('\n') {
142 markdown.push('\n');
143 }
144 _ = writeln!(markdown, "+++");
145 markdown.push('\n');
146
147 _ = writeln!(markdown, "# {}", self.name);
148 markdown.push('\n');
149
150 if let Some(reasoning) = &self.reasoning {
151 _ = writeln!(markdown, "## {}", REASONING_HEADING);
152 markdown.push('\n');
153 markdown.push_str(reasoning);
154 if !markdown.ends_with('\n') {
155 markdown.push('\n');
156 }
157 markdown.push('\n');
158 }
159
160 if !self.uncommitted_diff.is_empty() {
161 _ = writeln!(markdown, "## {}", UNCOMMITTED_DIFF_HEADING);
162 _ = writeln!(markdown);
163 _ = writeln!(markdown, "```diff");
164 markdown.push_str(&self.uncommitted_diff);
165 if !markdown.ends_with('\n') {
166 markdown.push('\n');
167 }
168 _ = writeln!(markdown, "```");
169 markdown.push('\n');
170 }
171
172 _ = writeln!(markdown, "## {}", EDIT_HISTORY_HEADING);
173 _ = writeln!(markdown);
174
175 if self.edit_history.is_empty() {
176 _ = writeln!(markdown, "(No edit history)");
177 _ = writeln!(markdown);
178 } else {
179 _ = writeln!(markdown, "```diff");
180 markdown.push_str(&self.edit_history);
181 if !markdown.ends_with('\n') {
182 markdown.push('\n');
183 }
184 _ = writeln!(markdown, "```");
185 markdown.push('\n');
186 }
187
188 _ = writeln!(markdown, "## {}", CURSOR_POSITION_HEADING);
189 _ = writeln!(markdown);
190 _ = writeln!(markdown, "```{}", self.cursor_path.to_string_lossy());
191 markdown.push_str(&self.cursor_position);
192 if !markdown.ends_with('\n') {
193 markdown.push('\n');
194 }
195 _ = writeln!(markdown, "```");
196 markdown.push('\n');
197
198 _ = writeln!(markdown, "## {}", EXPECTED_PATCH_HEADING);
199 markdown.push('\n');
200 for patch in &self.expected_patches {
201 _ = writeln!(markdown, "```diff");
202 markdown.push_str(patch);
203 if !markdown.ends_with('\n') {
204 markdown.push('\n');
205 }
206 _ = writeln!(markdown, "```");
207 markdown.push('\n');
208 }
209
210 if let Some(rejected_patch) = &self.rejected_patch {
211 _ = writeln!(markdown, "## {}", REJECTED_PATCH_HEADING);
212 markdown.push('\n');
213 _ = writeln!(markdown, "```diff");
214 markdown.push_str(rejected_patch);
215 if !markdown.ends_with('\n') {
216 markdown.push('\n');
217 }
218 _ = writeln!(markdown, "```");
219 markdown.push('\n');
220 }
221
222 markdown
223 }
224
225 /// Parse an example spec from markdown.
226 pub fn from_markdown(mut input: &str) -> anyhow::Result<Self> {
227 use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
228
229 let mut spec = ExampleSpec {
230 name: String::new(),
231 repository_url: String::new(),
232 revision: String::new(),
233 tags: Vec::new(),
234 reasoning: None,
235 uncommitted_diff: String::new(),
236 cursor_path: Path::new("").into(),
237 cursor_position: String::new(),
238 edit_history: String::new(),
239 expected_patches: Vec::new(),
240 rejected_patch: None,
241 captured_prompt_input: None,
242 };
243
244 if let Some(rest) = input.strip_prefix("+++\n")
245 && let Some((front_matter, rest)) = rest.split_once("+++\n")
246 {
247 if let Ok(data) = toml::from_str::<FrontMatter<'_>>(front_matter) {
248 spec.repository_url = data.repository_url.into_owned();
249 spec.revision = data.revision.into_owned();
250 spec.tags = data.tags;
251 }
252 input = rest.trim_start();
253 }
254
255 let parser = Parser::new(input);
256 let mut text = String::new();
257 let mut block_info: CowStr = "".into();
258
259 #[derive(PartialEq)]
260 enum Section {
261 Start,
262 UncommittedDiff,
263 EditHistory,
264 CursorPosition,
265 ExpectedPatch,
266 RejectedPatch,
267 Other,
268 }
269
270 let mut current_section = Section::Start;
271
272 for event in parser {
273 match event {
274 Event::Text(line) => {
275 text.push_str(&line);
276 }
277 Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
278 spec.name = mem::take(&mut text);
279 }
280 Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
281 let title = mem::take(&mut text);
282 current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
283 Section::UncommittedDiff
284 } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
285 Section::EditHistory
286 } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
287 Section::CursorPosition
288 } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
289 Section::ExpectedPatch
290 } else if title.eq_ignore_ascii_case(REJECTED_PATCH_HEADING) {
291 Section::RejectedPatch
292 } else {
293 Section::Other
294 };
295 }
296 Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
297 mem::take(&mut text);
298 }
299 Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
300 mem::take(&mut text);
301 }
302 Event::End(TagEnd::Heading(level)) => {
303 anyhow::bail!("Unexpected heading level: {level}");
304 }
305 Event::Start(Tag::CodeBlock(kind)) => {
306 match kind {
307 CodeBlockKind::Fenced(info) => {
308 block_info = info;
309 }
310 CodeBlockKind::Indented => {
311 anyhow::bail!("Unexpected indented codeblock");
312 }
313 };
314 }
315 Event::Start(_) => {
316 text.clear();
317 block_info = "".into();
318 }
319 Event::End(TagEnd::CodeBlock) => {
320 let block_info = block_info.trim();
321 match current_section {
322 Section::UncommittedDiff => {
323 spec.uncommitted_diff = mem::take(&mut text);
324 }
325 Section::EditHistory => {
326 spec.edit_history.push_str(&mem::take(&mut text));
327 }
328 Section::CursorPosition => {
329 spec.cursor_path = Path::new(block_info).into();
330 spec.cursor_position = mem::take(&mut text);
331 }
332 Section::ExpectedPatch => {
333 spec.expected_patches.push(mem::take(&mut text));
334 }
335 Section::RejectedPatch => {
336 spec.rejected_patch = Some(mem::take(&mut text));
337 }
338 Section::Start | Section::Other => {}
339 }
340 }
341 _ => {}
342 }
343 }
344
345 if spec.cursor_path.as_ref() == Path::new("") || spec.cursor_position.is_empty() {
346 anyhow::bail!("Missing cursor position codeblock");
347 }
348
349 Ok(spec)
350 }
351
352 /// Returns the excerpt of text around the cursor, and the offset of the cursor within that
353 /// excerpt.
354 ///
355 /// The cursor's position is marked with a special comment that appears
356 /// below the cursor line, which contains the string `[CURSOR_POSITION]`,
357 /// preceded by an arrow marking the cursor's column. The arrow can be
358 /// either:
359 /// - `^` - The cursor column is at the position of the `^` character (pointing up to the cursor)
360 /// - `<` - The cursor column is at the first non-whitespace character on that line.
361 pub fn cursor_excerpt(&self) -> Result<(String, usize)> {
362 let input = &self.cursor_position;
363
364 // Check for inline cursor marker first
365 if let Some(inline_offset) = input.find(INLINE_CURSOR_MARKER) {
366 let excerpt = input[..inline_offset].to_string()
367 + &input[inline_offset + INLINE_CURSOR_MARKER.len()..];
368 return Ok((excerpt, inline_offset));
369 }
370
371 let marker_offset = input
372 .find(CURSOR_POSITION_MARKER)
373 .context("missing [CURSOR_POSITION] marker")?;
374 let marker_line_start = input[..marker_offset]
375 .rfind('\n')
376 .map(|pos| pos + 1)
377 .unwrap_or(0);
378 let marker_line_end = input[marker_line_start..]
379 .find('\n')
380 .map(|pos| marker_line_start + pos + 1)
381 .unwrap_or(input.len());
382 let marker_line = &input[marker_line_start..marker_line_end].trim_end_matches('\n');
383
384 let cursor_column = if let Some(cursor_offset) = marker_line.find('^') {
385 cursor_offset
386 } else if let Some(less_than_pos) = marker_line.find('<') {
387 marker_line
388 .find(|c: char| !c.is_whitespace())
389 .unwrap_or(less_than_pos)
390 } else {
391 anyhow::bail!(
392 "cursor position marker line must contain '^' or '<' before [CURSOR_POSITION]"
393 );
394 };
395
396 let mut excerpt = input[..marker_line_start].to_string() + &input[marker_line_end..];
397 excerpt.truncate(excerpt.trim_end_matches('\n').len());
398
399 // The cursor is on the line above the marker line.
400 let cursor_line_end = marker_line_start.saturating_sub(1);
401 let cursor_line_start = excerpt[..cursor_line_end]
402 .rfind('\n')
403 .map(|pos| pos + 1)
404 .unwrap_or(0);
405 let cursor_offset = cursor_line_start + cursor_column;
406
407 Ok((excerpt, cursor_offset))
408 }
409
410 /// Sets the cursor position excerpt from a plain excerpt and cursor byte offset.
411 ///
412 /// The `line_comment_prefix` is used to format the marker line as a comment.
413 /// If the cursor column is less than the comment prefix length, the `<` format is used.
414 /// Otherwise, the `^` format is used.
415 pub fn set_cursor_excerpt(
416 &mut self,
417 excerpt: &str,
418 cursor_offset: usize,
419 line_comment_prefix: &str,
420 ) {
421 // Find which line the cursor is on and its column
422 let cursor_line_start = excerpt[..cursor_offset]
423 .rfind('\n')
424 .map(|pos| pos + 1)
425 .unwrap_or(0);
426 let cursor_line_end = excerpt[cursor_line_start..]
427 .find('\n')
428 .map(|pos| cursor_line_start + pos + 1)
429 .unwrap_or(excerpt.len());
430 let cursor_line = &excerpt[cursor_line_start..cursor_line_end];
431 let cursor_line_indent = &cursor_line[..cursor_line.len() - cursor_line.trim_start().len()];
432 let cursor_column = cursor_offset - cursor_line_start;
433
434 // Build the marker line
435 let mut marker_line = String::new();
436 if cursor_column < line_comment_prefix.len() {
437 for _ in 0..cursor_column {
438 marker_line.push(' ');
439 }
440 marker_line.push_str(line_comment_prefix);
441 write!(marker_line, " <{}", CURSOR_POSITION_MARKER).unwrap();
442 } else {
443 if cursor_column >= cursor_line_indent.len() + line_comment_prefix.len() {
444 marker_line.push_str(cursor_line_indent);
445 }
446 marker_line.push_str(line_comment_prefix);
447 while marker_line.len() < cursor_column {
448 marker_line.push(' ');
449 }
450 write!(marker_line, "^{}", CURSOR_POSITION_MARKER).unwrap();
451 }
452
453 // Build the final cursor_position string
454 let mut result = String::with_capacity(excerpt.len() + marker_line.len() + 2);
455 result.push_str(&excerpt[..cursor_line_end]);
456 if !result.ends_with('\n') {
457 result.push('\n');
458 }
459 result.push_str(&marker_line);
460 if cursor_line_end < excerpt.len() {
461 result.push('\n');
462 result.push_str(&excerpt[cursor_line_end..]);
463 }
464
465 self.cursor_position = result;
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;

    /// Creates a spec in which every field is empty except `cursor_path` (`test.rs`).
    fn empty_spec() -> ExampleSpec {
        ExampleSpec {
            name: String::new(),
            repository_url: String::new(),
            revision: String::new(),
            tags: Vec::new(),
            reasoning: None,
            uncommitted_diff: String::new(),
            cursor_path: Path::new("test.rs").into(),
            cursor_position: String::new(),
            edit_history: String::new(),
            expected_patches: Vec::new(),
            rejected_patch: None,
            captured_prompt_input: None,
        }
    }

    /// Stores the cursor in `spec` using `//` as the comment prefix, then checks both the
    /// generated marker text and that decoding it yields the original excerpt and offset.
    #[track_caller]
    fn assert_marker_round_trip(
        spec: &mut ExampleSpec,
        excerpt: &str,
        offset: usize,
        position_string: &str,
    ) {
        spec.set_cursor_excerpt(excerpt, offset, "//");
        assert_eq!(spec.cursor_position, position_string.to_string());
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (excerpt.to_string(), offset)
        );
    }

    /// Checks that the spec's current `cursor_position` decodes to the given excerpt and
    /// offset.
    #[track_caller]
    fn assert_decodes_to(spec: &ExampleSpec, expected_excerpt: &str, expected_offset: usize) {
        assert_eq!(
            spec.cursor_excerpt().unwrap(),
            (expected_excerpt.to_string(), expected_offset)
        );
    }

    #[test]
    fn test_cursor_excerpt_with_caret() {
        let mut spec = empty_spec();

        let excerpt = indoc! {"
            fn main() {
                let x = 42;
                println!(\"{}\", x);
            }"
        };

        // Cursor before `42`
        assert_marker_round_trip(
            &mut spec,
            excerpt,
            excerpt.find("42").unwrap(),
            indoc! {"
                fn main() {
                    let x = 42;
                    //      ^[CURSOR_POSITION]
                    println!(\"{}\", x);
                }"
            },
        );

        // Cursor after `l` in `let`
        assert_marker_round_trip(
            &mut spec,
            excerpt,
            excerpt.find("et x").unwrap(),
            indoc! {"
                fn main() {
                    let x = 42;
                //   ^[CURSOR_POSITION]
                    println!(\"{}\", x);
                }"
            },
        );

        // Cursor before `let`
        assert_marker_round_trip(
            &mut spec,
            excerpt,
            excerpt.find("let").unwrap(),
            indoc! {"
                fn main() {
                    let x = 42;
                //  ^[CURSOR_POSITION]
                    println!(\"{}\", x);
                }"
            },
        );

        // Cursor at beginning of the line with `let`
        assert_marker_round_trip(
            &mut spec,
            excerpt,
            excerpt.find("    let").unwrap(),
            indoc! {"
                fn main() {
                    let x = 42;
                // <[CURSOR_POSITION]
                    println!(\"{}\", x);
                }"
            },
        );

        // Cursor at end of line, after the semicolon
        assert_marker_round_trip(
            &mut spec,
            excerpt,
            excerpt.find(';').unwrap() + 1,
            indoc! {"
                fn main() {
                    let x = 42;
                    //         ^[CURSOR_POSITION]
                    println!(\"{}\", x);
                }"
            },
        );

        // Caret at end of file (no trailing newline)
        let excerpt = indoc! {"
            fn main() {
                let x = 42;"
        };
        assert_marker_round_trip(
            &mut spec,
            excerpt,
            excerpt.find(';').unwrap() + 1,
            indoc! {"
                fn main() {
                    let x = 42;
                    //         ^[CURSOR_POSITION]"
            },
        );
    }

    #[test]
    fn test_cursor_excerpt_with_inline_marker() {
        let mut spec = empty_spec();

        // Cursor before `42` using inline marker
        spec.cursor_position = indoc! {"
            fn main() {
                let x = <|user_cursor|>42;
                println!(\"{}\", x);
            }"
        }
        .to_string();
        let expected_excerpt = indoc! {"
            fn main() {
                let x = 42;
                println!(\"{}\", x);
            }"
        };
        assert_decodes_to(
            &spec,
            expected_excerpt,
            expected_excerpt.find("42").unwrap(),
        );

        // Cursor at beginning of line
        spec.cursor_position = indoc! {"
            fn main() {
            <|user_cursor|>    let x = 42;
            }"
        }
        .to_string();
        let expected_excerpt = indoc! {"
            fn main() {
                let x = 42;
            }"
        };
        assert_decodes_to(
            &spec,
            expected_excerpt,
            expected_excerpt.find("    let").unwrap(),
        );

        // Cursor at end of file
        spec.cursor_position = "fn main() {}<|user_cursor|>".to_string();
        let expected_excerpt = "fn main() {}";
        assert_decodes_to(&spec, expected_excerpt, expected_excerpt.len());
    }
}