1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::fmt::Write;
4use std::ops::Range;
5use std::path::Path;
6use std::sync::Arc;
7use strum::{EnumIter, IntoEnumIterator as _, IntoStaticStr};
8
/// Marker inserted into the editable region at the user's cursor position.
pub const CURSOR_MARKER: &str = "<|user_cursor|>";
/// Token budget used by `format_zeta_prompt` (token counts are estimated via `estimate_tokens`).
pub const MAX_PROMPT_TOKENS: usize = 4096;
11
/// Roughly estimates how many model tokens `bytes` bytes of prompt text will use.
///
/// Assumes about three bytes per token and rounds down, so pieces shorter than three
/// bytes are counted as free.
fn estimate_tokens(bytes: usize) -> usize {
    const BYTES_PER_TOKEN: usize = 3;
    bytes / BYTES_PER_TOKEN
}
15
/// Everything needed to build a Zeta prompt for a single prediction request.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ZetaPromptInput {
    /// Path of the file containing the cursor; written after `<|file_sep|>` in the
    /// cursor section.
    pub cursor_path: Arc<Path>,
    /// Text of the cursor file around the cursor. The range and offset below are byte
    /// indices into this string.
    pub cursor_excerpt: Arc<str>,
    /// Byte range of `cursor_excerpt` that the model may rewrite. The cursor-section
    /// writers slice `cursor_excerpt` with it, so it must be in bounds and on UTF-8 char
    /// boundaries or formatting panics.
    pub editable_range_in_excerpt: Range<usize>,
    /// Byte offset in `cursor_excerpt` where `CURSOR_MARKER` is inserted. Must satisfy
    /// `editable_range_in_excerpt.start <= offset <= editable_range_in_excerpt.end`,
    /// otherwise slicing panics.
    pub cursor_offset_in_excerpt: usize,
    /// Edit history. The end of the list is treated as most recent: when the token
    /// budget is tight, events are dropped from the front of the list first.
    pub events: Vec<Arc<Event>>,
    /// Context excerpts from other files, emitted before the edit history in list order.
    pub related_files: Vec<RelatedFile>,
}
25
/// Supported prompt layouts.
///
/// Variant order is significant: `ZetaVersion::iter()` follows it, which determines the
/// listing produced by `options_as_string` and the candidates `parse` matches against.
/// `Display`/`IntoStaticStr` render the variant name as written here.
#[derive(
    Default,
    Clone,
    Copy,
    Debug,
    PartialEq,
    Eq,
    Hash,
    EnumIter,
    IntoStaticStr,
    Serialize,
    Deserialize,
)]
// NOTE(review): every variant is already UpperCamelCase, so this allow appears to be
// unnecessary — confirm before removing.
#[allow(non_camel_case_types)]
pub enum ZetaVersion {
    /// Prefix, then suffix, then the current editable region (see `v0112_middle_at_end`).
    V0112MiddleAtEnd,
    /// Prefix, current editable region, suffix in document order (see `v0113_ordered`).
    V0113Ordered,
    /// The default version; currently formatted identically to `V0113Ordered`.
    #[default]
    V0114180EditableRegion,
    /// Editable region wrapped in git merge-conflict markers (see `v0120_git_merge_markers`).
    V0120GitMergeMarkers,
}
47
48impl std::fmt::Display for ZetaVersion {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 write!(f, "{}", <&'static str>::from(self))
51 }
52}
53
54impl ZetaVersion {
55 pub fn parse(version_string: &str) -> Result<Self> {
56 let mut results = ZetaVersion::iter().filter(|version| {
57 <&'static str>::from(version)
58 .to_lowercase()
59 .contains(&version_string.to_lowercase())
60 });
61 let Some(result) = results.next() else {
62 anyhow::bail!(
63 "`{version_string}` did not match any of:\n{}",
64 Self::options_as_string()
65 );
66 };
67 if results.next().is_some() {
68 anyhow::bail!(
69 "`{version_string}` matched more than one of:\n{}",
70 Self::options_as_string()
71 );
72 }
73 Ok(result)
74 }
75
76 pub fn options_as_string() -> String {
77 ZetaVersion::iter()
78 .map(|version| format!("- {}\n", <&'static str>::from(version)))
79 .collect::<Vec<_>>()
80 .concat()
81 }
82}
83
/// An entry in the user's recent editing history.
///
/// Serialized as an internally tagged enum (`#[serde(tag = "event")]`): the variant name is
/// stored under an `"event"` key alongside the variant's fields.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer, rendered in the prompt by `write_event` as a unified diff.
    BufferChange {
        /// Path after the change; rendered on the `+++ b/...` line.
        path: Arc<Path>,
        /// Path before the change; rendered on the `--- a/...` line. Presumably differs
        /// from `path` only when the file was renamed — TODO confirm with producers.
        old_path: Arc<Path>,
        /// Diff body, copied verbatim into the prompt after the `---`/`+++` lines.
        diff: String,
        /// When true, the diff is preceded by a `// User accepted prediction:` line.
        predicted: bool,
        /// Not used when formatting the prompt (`write_event` ignores it).
        in_open_source_repo: bool,
    },
}
95
96pub fn write_event(prompt: &mut String, event: &Event) {
97 fn write_path_as_unix_str(prompt: &mut String, path: &Path) {
98 for component in path.components() {
99 prompt.push('/');
100 write!(prompt, "{}", component.as_os_str().display()).ok();
101 }
102 }
103 match event {
104 Event::BufferChange {
105 path,
106 old_path,
107 diff,
108 predicted,
109 in_open_source_repo: _,
110 } => {
111 if *predicted {
112 prompt.push_str("// User accepted prediction:\n");
113 }
114 prompt.push_str("--- a");
115 write_path_as_unix_str(prompt, old_path.as_ref());
116 prompt.push_str("\n+++ b");
117 write_path_as_unix_str(prompt, path.as_ref());
118 prompt.push('\n');
119 prompt.push_str(diff);
120 }
121 }
122}
123
/// A file other than the cursor file whose excerpts are included as prompt context.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RelatedFile {
    /// Path written after `<|file_sep|>` in this file's header line.
    pub path: Arc<Path>,
    /// Row bound used to decide whether an excerpt is followed by a `...` line: any excerpt
    /// whose `row_range.end` is less than this gets one. Presumably the file's row count —
    /// TODO confirm with producers.
    pub max_row: u32,
    /// Excerpts, emitted in this order.
    pub excerpts: Vec<RelatedExcerpt>,
}
130
/// A contiguous slice of a related file.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RelatedExcerpt {
    /// Rows covered by `text`. Within this file only `end` is read, to decide whether a
    /// `...` line follows the excerpt.
    pub row_range: Range<u32>,
    /// Excerpt text; when formatted, a newline is appended if the prompt does not already
    /// end with one.
    pub text: Arc<str>,
}
136
137pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> String {
138 format_zeta_prompt_with_budget(input, version, MAX_PROMPT_TOKENS)
139}
140
141fn format_zeta_prompt_with_budget(
142 input: &ZetaPromptInput,
143 version: ZetaVersion,
144 max_tokens: usize,
145) -> String {
146 let mut cursor_section = String::new();
147 match version {
148 ZetaVersion::V0112MiddleAtEnd => {
149 v0112_middle_at_end::write_cursor_excerpt_section(&mut cursor_section, input);
150 }
151 ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => {
152 v0113_ordered::write_cursor_excerpt_section(&mut cursor_section, input)
153 }
154 ZetaVersion::V0120GitMergeMarkers => {
155 v0120_git_merge_markers::write_cursor_excerpt_section(&mut cursor_section, input)
156 }
157 }
158
159 let cursor_tokens = estimate_tokens(cursor_section.len());
160 let budget_after_cursor = max_tokens.saturating_sub(cursor_tokens);
161
162 let edit_history_section =
163 format_edit_history_within_budget(&input.events, budget_after_cursor);
164 let edit_history_tokens = estimate_tokens(edit_history_section.len());
165 let budget_after_edit_history = budget_after_cursor.saturating_sub(edit_history_tokens);
166
167 let related_files_section =
168 format_related_files_within_budget(&input.related_files, budget_after_edit_history);
169
170 let mut prompt = String::new();
171 prompt.push_str(&related_files_section);
172 prompt.push_str(&edit_history_section);
173 prompt.push_str(&cursor_section);
174 prompt
175}
176
177fn format_edit_history_within_budget(events: &[Arc<Event>], max_tokens: usize) -> String {
178 let header = "<|file_sep|>edit history\n";
179 let header_tokens = estimate_tokens(header.len());
180 if header_tokens >= max_tokens {
181 return String::new();
182 }
183
184 let mut event_strings: Vec<String> = Vec::new();
185 let mut total_tokens = header_tokens;
186
187 for event in events.iter().rev() {
188 let mut event_str = String::new();
189 write_event(&mut event_str, event);
190 let event_tokens = estimate_tokens(event_str.len());
191
192 if total_tokens + event_tokens > max_tokens {
193 break;
194 }
195 total_tokens += event_tokens;
196 event_strings.push(event_str);
197 }
198
199 if event_strings.is_empty() {
200 return String::new();
201 }
202
203 let mut result = String::from(header);
204 for event_str in event_strings.iter().rev() {
205 result.push_str(&event_str);
206 }
207 result
208}
209
210fn format_related_files_within_budget(related_files: &[RelatedFile], max_tokens: usize) -> String {
211 let mut result = String::new();
212 let mut total_tokens = 0;
213
214 for file in related_files {
215 let path_str = file.path.to_string_lossy();
216 let header_len = "<|file_sep|>".len() + path_str.len() + 1;
217 let header_tokens = estimate_tokens(header_len);
218
219 if total_tokens + header_tokens > max_tokens {
220 break;
221 }
222
223 let mut file_tokens = header_tokens;
224 let mut excerpts_to_include = 0;
225
226 for excerpt in &file.excerpts {
227 let needs_newline = !excerpt.text.ends_with('\n');
228 let needs_ellipsis = excerpt.row_range.end < file.max_row;
229 let excerpt_len = excerpt.text.len()
230 + if needs_newline { "\n".len() } else { "".len() }
231 + if needs_ellipsis {
232 "...\n".len()
233 } else {
234 "".len()
235 };
236
237 let excerpt_tokens = estimate_tokens(excerpt_len);
238 if total_tokens + file_tokens + excerpt_tokens > max_tokens {
239 break;
240 }
241 file_tokens += excerpt_tokens;
242 excerpts_to_include += 1;
243 }
244
245 if excerpts_to_include > 0 {
246 total_tokens += file_tokens;
247 write!(result, "<|file_sep|>{}\n", path_str).ok();
248 for excerpt in file.excerpts.iter().take(excerpts_to_include) {
249 result.push_str(&excerpt.text);
250 if !result.ends_with('\n') {
251 result.push('\n');
252 }
253 if excerpt.row_range.end < file.max_row {
254 result.push_str("...\n");
255 }
256 }
257 }
258 }
259
260 result
261}
262
263pub fn write_related_files(
264 prompt: &mut String,
265 related_files: &[RelatedFile],
266) -> Vec<Range<usize>> {
267 let mut ranges = Vec::new();
268 for file in related_files {
269 let start = prompt.len();
270 let path_str = file.path.to_string_lossy();
271 write!(prompt, "<|file_sep|>{}\n", path_str).ok();
272 for excerpt in &file.excerpts {
273 prompt.push_str(&excerpt.text);
274 if !prompt.ends_with('\n') {
275 prompt.push('\n');
276 }
277 if excerpt.row_range.end < file.max_row {
278 prompt.push_str("...\n");
279 }
280 }
281 let end = prompt.len();
282 ranges.push(start..end);
283 }
284 ranges
285}
286
287mod v0112_middle_at_end {
288 use super::*;
289
290 pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
291 let path_str = input.cursor_path.to_string_lossy();
292 write!(prompt, "<|file_sep|>{}\n", path_str).ok();
293
294 prompt.push_str("<|fim_prefix|>\n");
295 prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
296
297 prompt.push_str("<|fim_suffix|>\n");
298 prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
299 if !prompt.ends_with('\n') {
300 prompt.push('\n');
301 }
302
303 prompt.push_str("<|fim_middle|>current\n");
304 prompt.push_str(
305 &input.cursor_excerpt
306 [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
307 );
308 prompt.push_str(CURSOR_MARKER);
309 prompt.push_str(
310 &input.cursor_excerpt
311 [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
312 );
313 if !prompt.ends_with('\n') {
314 prompt.push('\n');
315 }
316
317 prompt.push_str("<|fim_middle|>updated\n");
318 }
319}
320
321mod v0113_ordered {
322 use super::*;
323
324 pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
325 let path_str = input.cursor_path.to_string_lossy();
326 write!(prompt, "<|file_sep|>{}\n", path_str).ok();
327
328 prompt.push_str("<|fim_prefix|>\n");
329 prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
330 if !prompt.ends_with('\n') {
331 prompt.push('\n');
332 }
333
334 prompt.push_str("<|fim_middle|>current\n");
335 prompt.push_str(
336 &input.cursor_excerpt
337 [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
338 );
339 prompt.push_str(CURSOR_MARKER);
340 prompt.push_str(
341 &input.cursor_excerpt
342 [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
343 );
344 if !prompt.ends_with('\n') {
345 prompt.push('\n');
346 }
347
348 prompt.push_str("<|fim_suffix|>\n");
349 prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
350 if !prompt.ends_with('\n') {
351 prompt.push('\n');
352 }
353
354 prompt.push_str("<|fim_middle|>updated\n");
355 }
356}
357
358pub mod v0120_git_merge_markers {
359 //! A prompt that uses git-style merge conflict markers to represent the editable region.
360 //!
361 //! Example prompt:
362 //!
363 //! <|file_sep|>path/to/target_file.py
364 //! <|fim_prefix|>
365 //! code before editable region
366 //! <|fim_suffix|>
367 //! code after editable region
368 //! <|fim_middle|>
369 //! <<<<<<< CURRENT
370 //! code that
371 //! needs to<|user_cursor|>
372 //! be rewritten
373 //! =======
374 //!
375 //! Expected output (should be generated by the model):
376 //!
377 //! updated
378 //! code with
379 //! changes applied
380 //! >>>>>>> UPDATED
381
382 use super::*;
383
384 pub const START_MARKER: &str = "<<<<<<< CURRENT\n";
385 pub const SEPARATOR: &str = "=======\n";
386 pub const END_MARKER: &str = ">>>>>>> UPDATED\n";
387
388 pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
389 let path_str = input.cursor_path.to_string_lossy();
390 write!(prompt, "<|file_sep|>{}\n", path_str).ok();
391
392 prompt.push_str("<|fim_prefix|>");
393 prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
394
395 prompt.push_str("<|fim_suffix|>");
396 prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
397 if !prompt.ends_with('\n') {
398 prompt.push('\n');
399 }
400
401 prompt.push_str("<|fim_middle|>");
402 prompt.push_str(START_MARKER);
403 prompt.push_str(
404 &input.cursor_excerpt
405 [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
406 );
407 prompt.push_str(CURSOR_MARKER);
408 prompt.push_str(
409 &input.cursor_excerpt
410 [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
411 );
412 if !prompt.ends_with('\n') {
413 prompt.push('\n');
414 }
415 prompt.push_str(SEPARATOR);
416 }
417}
418
#[cfg(test)]
mod tests {
    // Tests for prompt assembly and token-budget truncation. Budgets are chosen against
    // `estimate_tokens` (bytes / 3, rounded down); the comments in each test give the byte
    // and token counts that decide which sections fit.
    use super::*;
    use indoc::indoc;

    /// Builds an input for `test.rs` with the given excerpt, editable range, cursor offset,
    /// history and related files.
    fn make_input(
        cursor_excerpt: &str,
        editable_range: Range<usize>,
        cursor_offset: usize,
        events: Vec<Event>,
        related_files: Vec<RelatedFile>,
    ) -> ZetaPromptInput {
        ZetaPromptInput {
            cursor_path: Path::new("test.rs").into(),
            cursor_excerpt: cursor_excerpt.into(),
            editable_range_in_excerpt: editable_range,
            cursor_offset_in_excerpt: cursor_offset,
            events: events.into_iter().map(Arc::new).collect(),
            related_files,
        }
    }

    /// A non-predicted, non-rename buffer change of `path` with the given diff body.
    fn make_event(path: &str, diff: &str) -> Event {
        Event::BufferChange {
            path: Path::new(path).into(),
            old_path: Path::new(path).into(),
            diff: diff.to_string(),
            predicted: false,
            in_open_source_repo: false,
        }
    }

    /// A related file with one excerpt covering all of `content`; since the excerpt ends at
    /// `max_row`, no `...` line is emitted after it.
    fn make_related_file(path: &str, content: &str) -> RelatedFile {
        RelatedFile {
            path: Path::new(path).into(),
            max_row: content.lines().count() as u32,
            excerpts: vec![RelatedExcerpt {
                row_range: 0..content.lines().count() as u32,
                text: content.into(),
            }],
        }
    }

    /// Formats with the default layout (`V0114180EditableRegion`) under `max_tokens`.
    fn format_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> String {
        format_zeta_prompt_with_budget(input, ZetaVersion::V0114180EditableRegion, max_tokens)
    }

    #[test]
    fn test_no_truncation_when_within_budget() {
        let input = make_input(
            "prefix\neditable\nsuffix",
            7..15,
            10,
            vec![make_event("a.rs", "-old\n+new\n")],
            vec![make_related_file("related.rs", "fn helper() {}\n")],
        );

        // A generous budget keeps every section, emitted as related files, edit history,
        // then the cursor section. The suffix starts with '\n', hence the blank line after
        // `<|fim_suffix|>`.
        assert_eq!(
            format_with_budget(&input, 10000),
            indoc! {r#"
                <|file_sep|>related.rs
                fn helper() {}
                <|file_sep|>edit history
                --- a/a.rs
                +++ b/a.rs
                -old
                +new
                <|file_sep|>test.rs
                <|fim_prefix|>
                prefix
                <|fim_middle|>current
                edi<|user_cursor|>table
                <|fim_suffix|>

                suffix
                <|fim_middle|>updated
            "#}
        );
    }

    #[test]
    fn test_truncation_drops_edit_history_when_budget_tight() {
        let input = make_input(
            "code",
            0..4,
            2,
            vec![make_event("a.rs", "-x\n+y\n")],
            vec![
                make_related_file("r1.rs", "a\n"),
                make_related_file("r2.rs", "b\n"),
            ],
        );

        assert_eq!(
            format_with_budget(&input, 10000),
            indoc! {r#"
                <|file_sep|>r1.rs
                a
                <|file_sep|>r2.rs
                b
                <|file_sep|>edit history
                --- a/a.rs
                +++ b/a.rs
                -x
                +y
                <|file_sep|>test.rs
                <|fim_prefix|>
                <|fim_middle|>current
                co<|user_cursor|>de
                <|fim_suffix|>
                <|fim_middle|>updated
            "#}
        );

        // Cursor section: 114 bytes -> 38 tokens, leaving 12 of 50. The edit history needs
        // 8 (25-byte header) + 9 (28-byte event) = 17, so it is dropped entirely. Each
        // related file costs 6 (18-byte header) + 0 (2-byte excerpt), so both fit exactly.
        assert_eq!(
            format_with_budget(&input, 50),
            indoc! {r#"
                <|file_sep|>r1.rs
                a
                <|file_sep|>r2.rs
                b
                <|file_sep|>test.rs
                <|fim_prefix|>
                <|fim_middle|>current
                co<|user_cursor|>de
                <|fim_suffix|>
                <|fim_middle|>updated
            "#}
        );
    }

    #[test]
    fn test_truncation_includes_partial_excerpts() {
        let input = make_input(
            "x",
            0..1,
            0,
            vec![],
            vec![RelatedFile {
                path: Path::new("big.rs").into(),
                max_row: 30,
                excerpts: vec![
                    RelatedExcerpt {
                        row_range: 0..10,
                        text: "first excerpt\n".into(),
                    },
                    RelatedExcerpt {
                        row_range: 10..20,
                        text: "second excerpt\n".into(),
                    },
                    RelatedExcerpt {
                        row_range: 20..30,
                        text: "third excerpt\n".into(),
                    },
                ],
            }],
        );

        // Excerpts ending before `max_row` (30) are followed by `...`; the last one is not.
        assert_eq!(
            format_with_budget(&input, 10000),
            indoc! {r#"
                <|file_sep|>big.rs
                first excerpt
                ...
                second excerpt
                ...
                third excerpt
                <|file_sep|>test.rs
                <|fim_prefix|>
                <|fim_middle|>current
                <|user_cursor|>x
                <|fim_suffix|>
                <|fim_middle|>updated
            "#}
        );

        // Cursor section: 111 bytes -> 37 tokens, leaving 13 of 50. The header costs 6
        // (19 bytes), and each excerpt plus its `...\n` costs 6 (18/19 bytes). Header plus
        // the first excerpt is 12, which fits; adding the second would be 18.
        assert_eq!(
            format_with_budget(&input, 50),
            indoc! {r#"
                <|file_sep|>big.rs
                first excerpt
                ...
                <|file_sep|>test.rs
                <|fim_prefix|>
                <|fim_middle|>current
                <|user_cursor|>x
                <|fim_suffix|>
                <|fim_middle|>updated
            "#}
        );
    }

    #[test]
    fn test_truncation_drops_older_events_first() {
        let input = make_input(
            "x",
            0..1,
            0,
            vec![make_event("old.rs", "-1\n"), make_event("new.rs", "-2\n")],
            vec![],
        );

        assert_eq!(
            format_with_budget(&input, 10000),
            indoc! {r#"
                <|file_sep|>edit history
                --- a/old.rs
                +++ b/old.rs
                -1
                --- a/new.rs
                +++ b/new.rs
                -2
                <|file_sep|>test.rs
                <|fim_prefix|>
                <|fim_middle|>current
                <|user_cursor|>x
                <|fim_suffix|>
                <|fim_middle|>updated
            "#}
        );

        // Cursor section: 37 tokens, leaving 18 of 55. Header (8) + newest event (29 bytes
        // -> 9) = 17 fits; the older event would add another 9.
        assert_eq!(
            format_with_budget(&input, 55),
            indoc! {r#"
                <|file_sep|>edit history
                --- a/new.rs
                +++ b/new.rs
                -2
                <|file_sep|>test.rs
                <|fim_prefix|>
                <|fim_middle|>current
                <|user_cursor|>x
                <|fim_suffix|>
                <|fim_middle|>updated
            "#}
        );
    }

    #[test]
    fn test_cursor_excerpt_always_included_with_minimal_budget() {
        let input = make_input(
            "fn main() {}",
            0..12,
            3,
            vec![make_event("a.rs", "-old\n+new\n")],
            vec![make_related_file("related.rs", "helper\n")],
        );

        // The cursor section alone (122 bytes -> 40 tokens) exceeds the budget of 30. It is
        // still emitted in full, and everything else is dropped.
        assert_eq!(
            format_with_budget(&input, 30),
            indoc! {r#"
                <|file_sep|>test.rs
                <|fim_prefix|>
                <|fim_middle|>current
                fn <|user_cursor|>main() {}
                <|fim_suffix|>
                <|fim_middle|>updated
            "#}
        );
    }
}