1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::fmt::Write;
4use std::ops::Range;
5use std::path::Path;
6use std::sync::Arc;
7use strum::{EnumIter, IntoEnumIterator as _, IntoStaticStr};
8
/// Marker text written into the prompt at the cursor offset of the cursor excerpt.
pub const CURSOR_MARKER: &str = "<|user_cursor|>";
/// Token budget that `format_zeta_prompt` passes to the budgeted formatter.
pub const MAX_PROMPT_TOKENS: usize = 4096;
11
/// Roughly estimates how many tokens a text of `bytes` bytes occupies.
///
/// The estimate is the byte count divided by three, rounded down.
fn estimate_tokens(bytes: usize) -> usize {
    const BYTES_PER_TOKEN: usize = 3;
    bytes / BYTES_PER_TOKEN
}
15
/// Everything the prompt formatters in this file read to build a prompt.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ZetaPromptInput {
    /// Path written after `<|file_sep|>` at the top of the cursor excerpt section.
    pub cursor_path: Arc<Path>,
    /// Text around the cursor; the offsets below are byte indices into this string.
    pub cursor_excerpt: Arc<str>,
    /// Byte range of `cursor_excerpt` that is emitted as the "current"/editable part.
    /// Text before `start` is emitted as the prefix and text after `end` as the suffix.
    pub editable_range_in_excerpt: Range<usize>,
    /// Byte offset in `cursor_excerpt` at which `CURSOR_MARKER` is inserted.
    /// It is used as a slice bound between the editable range's start and end, so it is
    /// expected to lie inside that range (slicing panics otherwise).
    pub cursor_offset_in_excerpt: usize,
    /// Edit-history events. When the budget is tight, entries at the end of the list are
    /// kept in preference to earlier ones.
    pub events: Vec<Arc<Event>>,
    /// Additional files whose excerpts are placed at the start of the prompt.
    pub related_files: Vec<RelatedFile>,
}
25
/// Selects which cursor-excerpt layout `format_zeta_prompt` produces.
///
/// The default is [`ZetaVersion::V0114180EditableRegion`].
#[derive(
    Default,
    Clone,
    Copy,
    Debug,
    PartialEq,
    Eq,
    Hash,
    EnumIter,
    IntoStaticStr,
    Serialize,
    Deserialize,
)]
#[allow(non_camel_case_types)]
pub enum ZetaVersion {
    /// Prefix, then suffix, then the current editable text (see `v0112_middle_at_end`).
    V0112MiddleAtEnd,
    /// Prefix, current editable text, then suffix (see `v0113_ordered`).
    V0113Ordered,
    /// Formatted with the same writer as `V0113Ordered` in this file.
    #[default]
    V0114180EditableRegion,
    /// Editable text wrapped in git-style merge markers (see `v0120_git_merge_markers`).
    V0120GitMergeMarkers,
}
47
48impl std::fmt::Display for ZetaVersion {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 write!(f, "{}", <&'static str>::from(self))
51 }
52}
53
54impl ZetaVersion {
55 pub fn parse(version_string: &str) -> Result<Self> {
56 let mut results = ZetaVersion::iter().filter(|version| {
57 <&'static str>::from(version)
58 .to_lowercase()
59 .contains(&version_string.to_lowercase())
60 });
61 let Some(result) = results.next() else {
62 anyhow::bail!(
63 "`{version_string}` did not match any of:\n{}",
64 Self::options_as_string()
65 );
66 };
67 if results.next().is_some() {
68 anyhow::bail!(
69 "`{version_string}` matched more than one of:\n{}",
70 Self::options_as_string()
71 );
72 }
73 Ok(result)
74 }
75
76 pub fn options_as_string() -> String {
77 ZetaVersion::iter()
78 .map(|version| format!("- {}\n", <&'static str>::from(version)))
79 .collect::<Vec<_>>()
80 .concat()
81 }
82}
83
/// An entry of the edit history. Serialized with an `event` tag field.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer, carried as a diff.
    BufferChange {
        /// Path written on the `+++ b` line by `write_event`.
        path: Arc<Path>,
        /// Path written on the `--- a` line by `write_event`.
        old_path: Arc<Path>,
        /// Diff text appended verbatim after the two path lines.
        diff: String,
        /// When true, `write_event` prefixes the entry with
        /// `// User accepted prediction:`.
        predicted: bool,
        /// Not used by `write_event`; presumably consumed elsewhere.
        in_open_source_repo: bool,
    },
}
95
96pub fn write_event(prompt: &mut String, event: &Event) {
97 fn write_path_as_unix_str(prompt: &mut String, path: &Path) {
98 for component in path.components() {
99 prompt.push('/');
100 write!(prompt, "{}", component.as_os_str().display()).ok();
101 }
102 }
103 match event {
104 Event::BufferChange {
105 path,
106 old_path,
107 diff,
108 predicted,
109 in_open_source_repo: _,
110 } => {
111 if *predicted {
112 prompt.push_str("// User accepted prediction:\n");
113 }
114 prompt.push_str("--- a");
115 write_path_as_unix_str(prompt, old_path.as_ref());
116 prompt.push_str("\n+++ b");
117 write_path_as_unix_str(prompt, path.as_ref());
118 prompt.push('\n');
119 prompt.push_str(diff);
120 }
121 }
122}
123
/// A file other than the cursor file, contributed to the prompt as a list of excerpts.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RelatedFile {
    /// Path written after `<|file_sep|>` in the file's header line.
    pub path: Arc<Path>,
    /// Row bound compared against each excerpt's `row_range.end`: an excerpt ending
    /// before this value is followed by a `...` line. Presumably the file's last row
    /// (or row count) — confirm with the producer of this data.
    pub max_row: u32,
    /// Excerpts of the file, written in the order given.
    pub excerpts: Vec<RelatedExcerpt>,
}
130
/// One piece of text taken from a related file.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RelatedExcerpt {
    /// Rows associated with `text`; only `end` is read in this file, to decide whether
    /// a `...` line follows the excerpt.
    pub row_range: Range<u32>,
    /// Excerpt text. A trailing newline is added when writing if it is missing.
    pub text: Arc<str>,
}
136
/// Formats the prompt for `input` in the layout selected by `version`, using
/// `MAX_PROMPT_TOKENS` as the token budget.
pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> String {
    format_zeta_prompt_with_budget(input, version, MAX_PROMPT_TOKENS)
}
140
141fn format_zeta_prompt_with_budget(
142 input: &ZetaPromptInput,
143 version: ZetaVersion,
144 max_tokens: usize,
145) -> String {
146 let mut cursor_section = String::new();
147 match version {
148 ZetaVersion::V0112MiddleAtEnd => {
149 v0112_middle_at_end::write_cursor_excerpt_section(&mut cursor_section, input);
150 }
151 ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => {
152 v0113_ordered::write_cursor_excerpt_section(&mut cursor_section, input)
153 }
154 ZetaVersion::V0120GitMergeMarkers => {
155 v0120_git_merge_markers::write_cursor_excerpt_section(&mut cursor_section, input)
156 }
157 }
158
159 let cursor_tokens = estimate_tokens(cursor_section.len());
160 let budget_after_cursor = max_tokens.saturating_sub(cursor_tokens);
161
162 let edit_history_section =
163 format_edit_history_within_budget(&input.events, budget_after_cursor);
164 let edit_history_tokens = estimate_tokens(edit_history_section.len());
165 let budget_after_edit_history = budget_after_cursor.saturating_sub(edit_history_tokens);
166
167 let related_files_section =
168 format_related_files_within_budget(&input.related_files, budget_after_edit_history);
169
170 let mut prompt = String::new();
171 prompt.push_str(&related_files_section);
172 prompt.push_str(&edit_history_section);
173 prompt.push_str(&cursor_section);
174 prompt
175}
176
177fn format_edit_history_within_budget(events: &[Arc<Event>], max_tokens: usize) -> String {
178 let header = "<|file_sep|>edit history\n";
179 let header_tokens = estimate_tokens(header.len());
180 if header_tokens >= max_tokens {
181 return String::new();
182 }
183
184 let mut event_strings: Vec<String> = Vec::new();
185 let mut total_tokens = header_tokens;
186
187 for event in events.iter().rev() {
188 let mut event_str = String::new();
189 write_event(&mut event_str, event);
190 let event_tokens = estimate_tokens(event_str.len());
191
192 if total_tokens + event_tokens > max_tokens {
193 break;
194 }
195 total_tokens += event_tokens;
196 event_strings.push(event_str);
197 }
198
199 if event_strings.is_empty() {
200 return String::new();
201 }
202
203 let mut result = String::from(header);
204 for event_str in event_strings.iter().rev() {
205 result.push_str(&event_str);
206 }
207 result
208}
209
210fn format_related_files_within_budget(related_files: &[RelatedFile], max_tokens: usize) -> String {
211 let mut result = String::new();
212 let mut total_tokens = 0;
213
214 for file in related_files {
215 let path_str = file.path.to_string_lossy();
216 let header_len = "<|file_sep|>".len() + path_str.len() + 1;
217 let header_tokens = estimate_tokens(header_len);
218
219 if total_tokens + header_tokens > max_tokens {
220 break;
221 }
222
223 let mut file_tokens = header_tokens;
224 let mut excerpts_to_include = 0;
225
226 for excerpt in &file.excerpts {
227 let needs_newline = !excerpt.text.ends_with('\n');
228 let needs_ellipsis = excerpt.row_range.end < file.max_row;
229 let excerpt_len = excerpt.text.len()
230 + if needs_newline { "\n".len() } else { "".len() }
231 + if needs_ellipsis {
232 "...\n".len()
233 } else {
234 "".len()
235 };
236
237 let excerpt_tokens = estimate_tokens(excerpt_len);
238 if total_tokens + file_tokens + excerpt_tokens > max_tokens {
239 break;
240 }
241 file_tokens += excerpt_tokens;
242 excerpts_to_include += 1;
243 }
244
245 if excerpts_to_include > 0 {
246 total_tokens += file_tokens;
247 write!(result, "<|file_sep|>{}\n", path_str).ok();
248 for excerpt in file.excerpts.iter().take(excerpts_to_include) {
249 result.push_str(&excerpt.text);
250 if !result.ends_with('\n') {
251 result.push('\n');
252 }
253 if excerpt.row_range.end < file.max_row {
254 result.push_str("...\n");
255 }
256 }
257 }
258 }
259
260 result
261}
262
263pub fn write_related_files(
264 prompt: &mut String,
265 related_files: &[RelatedFile],
266) -> Vec<Range<usize>> {
267 let mut ranges = Vec::new();
268 for file in related_files {
269 let start = prompt.len();
270 let path_str = file.path.to_string_lossy();
271 write!(prompt, "<|file_sep|>{}\n", path_str).ok();
272 for excerpt in &file.excerpts {
273 prompt.push_str(&excerpt.text);
274 if !prompt.ends_with('\n') {
275 prompt.push('\n');
276 }
277 if excerpt.row_range.end < file.max_row {
278 prompt.push_str("...\n");
279 }
280 }
281 let end = prompt.len();
282 ranges.push(start..end);
283 }
284 ranges
285}
286
287mod v0112_middle_at_end {
288 use super::*;
289
290 pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
291 let path_str = input.cursor_path.to_string_lossy();
292 write!(prompt, "<|file_sep|>{}\n", path_str).ok();
293
294 prompt.push_str("<|fim_prefix|>\n");
295 prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
296
297 prompt.push_str("<|fim_suffix|>\n");
298 prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
299 if !prompt.ends_with('\n') {
300 prompt.push('\n');
301 }
302
303 prompt.push_str("<|fim_middle|>current\n");
304 prompt.push_str(
305 &input.cursor_excerpt
306 [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
307 );
308 prompt.push_str(CURSOR_MARKER);
309 prompt.push_str(
310 &input.cursor_excerpt
311 [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
312 );
313 if !prompt.ends_with('\n') {
314 prompt.push('\n');
315 }
316
317 prompt.push_str("<|fim_middle|>updated\n");
318 }
319}
320
321mod v0113_ordered {
322 use super::*;
323
324 pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
325 let path_str = input.cursor_path.to_string_lossy();
326 write!(prompt, "<|file_sep|>{}\n", path_str).ok();
327
328 prompt.push_str("<|fim_prefix|>\n");
329 prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
330 if !prompt.ends_with('\n') {
331 prompt.push('\n');
332 }
333
334 prompt.push_str("<|fim_middle|>current\n");
335 prompt.push_str(
336 &input.cursor_excerpt
337 [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
338 );
339 prompt.push_str(CURSOR_MARKER);
340 prompt.push_str(
341 &input.cursor_excerpt
342 [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
343 );
344 if !prompt.ends_with('\n') {
345 prompt.push('\n');
346 }
347
348 prompt.push_str("<|fim_suffix|>\n");
349 prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
350 if !prompt.ends_with('\n') {
351 prompt.push('\n');
352 }
353
354 prompt.push_str("<|fim_middle|>updated\n");
355 }
356}
357
358pub mod v0120_git_merge_markers {
359 //! A prompt that uses git-style merge conflict markers to represent the editable region.
360 //!
361 //! Example prompt:
362 //!
363 //! <|file_sep|>path/to/target_file.py
364 //! <|fim_prefix|>
365 //! code before editable region
366 //! <|fim_suffix|>
367 //! code after editable region
368 //! <|fim_middle|>
369 //! <<<<<<< CURRENT
370 //! code that
371 //! needs to<|user_cursor|>
372 //! be rewritten
373 //! =======
374 //!
375 //! Expected output (should be generated by the model):
376 //!
377 //! updated
378 //! code with
379 //! changes applied
380 //! >>>>>>> UPDATED
381
382 use super::*;
383
384 pub const START_MARKER: &str = "<<<<<<< CURRENT\n";
385 pub const SEPARATOR: &str = "=======\n";
386 pub const END_MARKER: &str = ">>>>>>> UPDATED\n";
387
388 pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
389 let path_str = input.cursor_path.to_string_lossy();
390 write!(prompt, "<|file_sep|>{}\n", path_str).ok();
391
392 prompt.push_str("<|fim_prefix|>");
393 prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
394
395 prompt.push_str("<|fim_suffix|>");
396 prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
397 if !prompt.ends_with('\n') {
398 prompt.push('\n');
399 }
400
401 prompt.push_str("<|fim_middle|>");
402 prompt.push_str(START_MARKER);
403 prompt.push_str(
404 &input.cursor_excerpt
405 [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
406 );
407 prompt.push_str(CURSOR_MARKER);
408 prompt.push_str(
409 &input.cursor_excerpt
410 [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
411 );
412 if !prompt.ends_with('\n') {
413 prompt.push('\n');
414 }
415 prompt.push_str(SEPARATOR);
416 }
417}
418
/// The zeta1 prompt format
pub mod zeta1 {
    /// Marker for the cursor position in zeta1 prompts.
    pub const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
    /// Marker for the start of a file in zeta1 prompts.
    pub const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
    /// Marker that opens the editable region in zeta1 prompts.
    pub const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
    /// Marker that closes the editable region in zeta1 prompts.
    pub const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";

    // Fixed text placed before the user edits.
    const INSTRUCTION_HEADER: &str = concat!(
        "### Instruction:\n",
        "You are a code completion assistant and your task is to analyze user edits and then rewrite an ",
        "excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking ",
        "into account the cursor location.\n\n",
        "### User Edits:\n\n"
    );
    // Fixed text placed between the user edits and the excerpt.
    const EXCERPT_HEADER: &str = "\n\n### User Excerpt:\n\n";
    // Fixed text placed after the excerpt, where the response begins.
    const RESPONSE_HEADER: &str = "\n\n### Response:\n";

    /// Assembles a zeta1 prompt: instruction header, `input_events`, excerpt header,
    /// `input_excerpt`, and the response header, in that order.
    pub fn format_zeta1_prompt(input_events: &str, input_excerpt: &str) -> String {
        [
            INSTRUCTION_HEADER,
            input_events,
            EXCERPT_HEADER,
            input_excerpt,
            RESPONSE_HEADER,
        ]
        .concat()
    }
}
453
// Tests for budget-aware prompt formatting. All of them format with
// `ZetaVersion::V0114180EditableRegion` through `format_with_budget`.
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;

    /// Builds a `ZetaPromptInput` whose cursor file is always `test.rs`.
    fn make_input(
        cursor_excerpt: &str,
        editable_range: Range<usize>,
        cursor_offset: usize,
        events: Vec<Event>,
        related_files: Vec<RelatedFile>,
    ) -> ZetaPromptInput {
        ZetaPromptInput {
            cursor_path: Path::new("test.rs").into(),
            cursor_excerpt: cursor_excerpt.into(),
            editable_range_in_excerpt: editable_range,
            cursor_offset_in_excerpt: cursor_offset,
            events: events.into_iter().map(Arc::new).collect(),
            related_files,
        }
    }

    /// Builds a non-predicted buffer change whose old and new paths are both `path`.
    fn make_event(path: &str, diff: &str) -> Event {
        Event::BufferChange {
            path: Path::new(path).into(),
            old_path: Path::new(path).into(),
            diff: diff.to_string(),
            predicted: false,
            in_open_source_repo: false,
        }
    }

    /// Builds a related file with a single excerpt spanning all lines of `content`,
    /// so no `...` line is written after it.
    fn make_related_file(path: &str, content: &str) -> RelatedFile {
        RelatedFile {
            path: Path::new(path).into(),
            max_row: content.lines().count() as u32,
            excerpts: vec![RelatedExcerpt {
                row_range: 0..content.lines().count() as u32,
                text: content.into(),
            }],
        }
    }

    /// Formats `input` with an explicit token budget.
    fn format_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> String {
        format_zeta_prompt_with_budget(input, ZetaVersion::V0114180EditableRegion, max_tokens)
    }

    // With a generous budget every section appears: related files, edit history,
    // then the cursor section.
    #[test]
    fn test_no_truncation_when_within_budget() {
        let input = make_input(
            "prefix\neditable\nsuffix",
            7..15,
            10,
            vec![make_event("a.rs", "-old\n+new\n")],
            vec![make_related_file("related.rs", "fn helper() {}\n")],
        );

        assert_eq!(
            format_with_budget(&input, 10000),
            indoc! {r#"
                <|file_sep|>related.rs
                fn helper() {}
                <|file_sep|>edit history
                --- a/a.rs
                +++ b/a.rs
                -old
                +new
                <|file_sep|>test.rs
                <|fim_prefix|>
                prefix
                <|fim_middle|>current
                edi<|user_cursor|>table
                <|fim_suffix|>

                suffix
                <|fim_middle|>updated
            "#}
        );
    }

    // At a budget of 50 the edit history is left out while both related files
    // are still present.
    #[test]
    fn test_truncation_drops_edit_history_when_budget_tight() {
        let input = make_input(
            "code",
            0..4,
            2,
            vec![make_event("a.rs", "-x\n+y\n")],
            vec![
                make_related_file("r1.rs", "a\n"),
                make_related_file("r2.rs", "b\n"),
            ],
        );

        assert_eq!(
            format_with_budget(&input, 10000),
            indoc! {r#"
                <|file_sep|>r1.rs
                a
                <|file_sep|>r2.rs
                b
                <|file_sep|>edit history
                --- a/a.rs
                +++ b/a.rs
                -x
                +y
                <|file_sep|>test.rs
                <|fim_prefix|>
                <|fim_middle|>current
                co<|user_cursor|>de
                <|fim_suffix|>
                <|fim_middle|>updated
            "#}
        );

        assert_eq!(
            format_with_budget(&input, 50),
            indoc! {r#"
                <|file_sep|>r1.rs
                a
                <|file_sep|>r2.rs
                b
                <|file_sep|>test.rs
                <|fim_prefix|>
                <|fim_middle|>current
                co<|user_cursor|>de
                <|fim_suffix|>
                <|fim_middle|>updated
            "#}
        );
    }

    // A related file may be included with only its leading excerpts: at a budget
    // of 50 only the first of three excerpts is written.
    #[test]
    fn test_truncation_includes_partial_excerpts() {
        let input = make_input(
            "x",
            0..1,
            0,
            vec![],
            vec![RelatedFile {
                path: Path::new("big.rs").into(),
                max_row: 30,
                excerpts: vec![
                    RelatedExcerpt {
                        row_range: 0..10,
                        text: "first excerpt\n".into(),
                    },
                    RelatedExcerpt {
                        row_range: 10..20,
                        text: "second excerpt\n".into(),
                    },
                    RelatedExcerpt {
                        row_range: 20..30,
                        text: "third excerpt\n".into(),
                    },
                ],
            }],
        );

        assert_eq!(
            format_with_budget(&input, 10000),
            indoc! {r#"
                <|file_sep|>big.rs
                first excerpt
                ...
                second excerpt
                ...
                third excerpt
                <|file_sep|>test.rs
                <|fim_prefix|>
                <|fim_middle|>current
                <|user_cursor|>x
                <|fim_suffix|>
                <|fim_middle|>updated
            "#}
        );

        assert_eq!(
            format_with_budget(&input, 50),
            indoc! {r#"
                <|file_sep|>big.rs
                first excerpt
                ...
                <|file_sep|>test.rs
                <|fim_prefix|>
                <|fim_middle|>current
                <|user_cursor|>x
                <|fim_suffix|>
                <|fim_middle|>updated
            "#}
        );
    }

    // When not all events fit, the event listed last is the one that is kept.
    #[test]
    fn test_truncation_drops_older_events_first() {
        let input = make_input(
            "x",
            0..1,
            0,
            vec![make_event("old.rs", "-1\n"), make_event("new.rs", "-2\n")],
            vec![],
        );

        assert_eq!(
            format_with_budget(&input, 10000),
            indoc! {r#"
                <|file_sep|>edit history
                --- a/old.rs
                +++ b/old.rs
                -1
                --- a/new.rs
                +++ b/new.rs
                -2
                <|file_sep|>test.rs
                <|fim_prefix|>
                <|fim_middle|>current
                <|user_cursor|>x
                <|fim_suffix|>
                <|fim_middle|>updated
            "#}
        );

        assert_eq!(
            format_with_budget(&input, 55),
            indoc! {r#"
                <|file_sep|>edit history
                --- a/new.rs
                +++ b/new.rs
                -2
                <|file_sep|>test.rs
                <|fim_prefix|>
                <|fim_middle|>current
                <|user_cursor|>x
                <|fim_suffix|>
                <|fim_middle|>updated
            "#}
        );
    }

    // The cursor section is emitted in full even when the budget leaves no room
    // for edit history or related files.
    #[test]
    fn test_cursor_excerpt_always_included_with_minimal_budget() {
        let input = make_input(
            "fn main() {}",
            0..12,
            3,
            vec![make_event("a.rs", "-old\n+new\n")],
            vec![make_related_file("related.rs", "helper\n")],
        );

        assert_eq!(
            format_with_budget(&input, 30),
            indoc! {r#"
                <|file_sep|>test.rs
                <|fim_prefix|>
                <|fim_middle|>current
                fn <|user_cursor|>main() {}
                <|fim_suffix|>
                <|fim_middle|>updated
            "#}
        );
    }
}