zeta_prompt.rs

  1use anyhow::Result;
  2use serde::{Deserialize, Serialize};
  3use std::fmt::Write;
  4use std::ops::Range;
  5use std::path::Path;
  6use std::sync::Arc;
  7use strum::{EnumIter, IntoEnumIterator as _, IntoStaticStr};
  8
  9pub const CURSOR_MARKER: &str = "<|user_cursor|>";
 10pub const MAX_PROMPT_TOKENS: usize = 4096;
 11
 12fn estimate_tokens(bytes: usize) -> usize {
 13    bytes / 3
 14}
 15
/// Everything needed to build a zeta prompt for one cursor position.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ZetaPromptInput {
    /// Path of the file containing the cursor; written after `<|file_sep|>` in the cursor
    /// section.
    pub cursor_path: Arc<Path>,
    /// Text surrounding the cursor. The two fields below index into it by byte offset.
    pub cursor_excerpt: Arc<str>,
    /// Byte range of `cursor_excerpt` treated as the editable region. The section writers
    /// slice the excerpt with it, so it must lie on `char` boundaries or they panic.
    pub editable_range_in_excerpt: Range<usize>,
    /// Byte offset of the cursor in `cursor_excerpt`. The section writers slice
    /// `editable.start..cursor` and `cursor..editable.end`, so it must fall within
    /// `editable_range_in_excerpt`.
    pub cursor_offset_in_excerpt: usize,
    /// Edit history, oldest first; when the token budget is tight the newest events are kept.
    pub events: Vec<Arc<Event>>,
    /// Context files included ahead of the edit history while the token budget allows.
    pub related_files: Vec<RelatedFile>,
}
 25
/// Selects which prompt layout [`format_zeta_prompt`] produces.
///
/// The `IntoStaticStr` string of each variant is what [`Display`](std::fmt::Display) prints
/// and what [`ZetaVersion::parse`] matches against.
#[derive(
    Default,
    Clone,
    Copy,
    Debug,
    PartialEq,
    Eq,
    Hash,
    EnumIter,
    IntoStaticStr,
    Serialize,
    Deserialize,
)]
// NOTE(review): every variant name is already UpperCamelCase, so this allow appears to be
// unnecessary — confirm before removing.
#[allow(non_camel_case_types)]
pub enum ZetaVersion {
    /// Cursor section written as prefix, suffix, then the editable region last.
    V0112MiddleAtEnd,
    /// Cursor section written in document order: prefix, editable region, suffix.
    V0113Ordered,
    /// Default version; formatted with the same cursor-section layout as
    /// [`ZetaVersion::V0113Ordered`].
    #[default]
    V0114180EditableRegion,
    /// Editable region wrapped in git-style merge-conflict markers.
    V0120GitMergeMarkers,
}
 47
 48impl std::fmt::Display for ZetaVersion {
 49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 50        write!(f, "{}", <&'static str>::from(self))
 51    }
 52}
 53
 54impl ZetaVersion {
 55    pub fn parse(version_string: &str) -> Result<Self> {
 56        let mut results = ZetaVersion::iter().filter(|version| {
 57            <&'static str>::from(version)
 58                .to_lowercase()
 59                .contains(&version_string.to_lowercase())
 60        });
 61        let Some(result) = results.next() else {
 62            anyhow::bail!(
 63                "`{version_string}` did not match any of:\n{}",
 64                Self::options_as_string()
 65            );
 66        };
 67        if results.next().is_some() {
 68            anyhow::bail!(
 69                "`{version_string}` matched more than one of:\n{}",
 70                Self::options_as_string()
 71            );
 72        }
 73        Ok(result)
 74    }
 75
 76    pub fn options_as_string() -> String {
 77        ZetaVersion::iter()
 78            .map(|version| format!("- {}\n", <&'static str>::from(version)))
 79            .collect::<Vec<_>>()
 80            .concat()
 81    }
 82}
 83
/// A single entry in the edit history passed to the prompt builders.
///
/// Serialized with serde's internally tagged representation, using an `event` field as the
/// tag.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer, rendered into prompts by [`write_event`].
    BufferChange {
        /// Written as the `+++ b/...` header (presumably the path after the change).
        path: Arc<Path>,
        /// Written as the `--- a/...` header (presumably the path before the change).
        old_path: Arc<Path>,
        /// Diff body, appended verbatim after the two headers.
        diff: String,
        /// When set, a `// User accepted prediction:` line precedes the headers.
        predicted: bool,
        /// Not used by the prompt formatting in this file.
        in_open_source_repo: bool,
    },
}
 95
 96pub fn write_event(prompt: &mut String, event: &Event) {
 97    fn write_path_as_unix_str(prompt: &mut String, path: &Path) {
 98        for component in path.components() {
 99            prompt.push('/');
100            write!(prompt, "{}", component.as_os_str().display()).ok();
101        }
102    }
103    match event {
104        Event::BufferChange {
105            path,
106            old_path,
107            diff,
108            predicted,
109            in_open_source_repo: _,
110        } => {
111            if *predicted {
112                prompt.push_str("// User accepted prediction:\n");
113            }
114            prompt.push_str("--- a");
115            write_path_as_unix_str(prompt, old_path.as_ref());
116            prompt.push_str("\n+++ b");
117            write_path_as_unix_str(prompt, path.as_ref());
118            prompt.push('\n');
119            prompt.push_str(diff);
120        }
121    }
122}
123
/// A file included as prompt context, represented by one or more excerpts.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RelatedFile {
    /// Path written after `<|file_sep|>` when the file is included.
    pub path: Arc<Path>,
    /// An excerpt whose `row_range.end` is below this value is followed by a `...` line.
    /// Presumably the file's total row count — TODO confirm with producers.
    pub max_row: u32,
    /// Excerpts written in order; under a tight budget only a leading prefix of them is kept.
    pub excerpts: Vec<RelatedExcerpt>,
}
130
/// A contiguous slice of a [`RelatedFile`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RelatedExcerpt {
    /// Rows covered by `text`, presumably end-exclusive (confirm). `end` is compared against
    /// [`RelatedFile::max_row`] to decide whether a `...` line follows the excerpt.
    pub row_range: Range<u32>,
    /// Excerpt text, written verbatim and newline-terminated if it is not already.
    pub text: Arc<str>,
}
136
137pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> String {
138    format_zeta_prompt_with_budget(input, version, MAX_PROMPT_TOKENS)
139}
140
141fn format_zeta_prompt_with_budget(
142    input: &ZetaPromptInput,
143    version: ZetaVersion,
144    max_tokens: usize,
145) -> String {
146    let mut cursor_section = String::new();
147    match version {
148        ZetaVersion::V0112MiddleAtEnd => {
149            v0112_middle_at_end::write_cursor_excerpt_section(&mut cursor_section, input);
150        }
151        ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => {
152            v0113_ordered::write_cursor_excerpt_section(&mut cursor_section, input)
153        }
154        ZetaVersion::V0120GitMergeMarkers => {
155            v0120_git_merge_markers::write_cursor_excerpt_section(&mut cursor_section, input)
156        }
157    }
158
159    let cursor_tokens = estimate_tokens(cursor_section.len());
160    let budget_after_cursor = max_tokens.saturating_sub(cursor_tokens);
161
162    let edit_history_section =
163        format_edit_history_within_budget(&input.events, budget_after_cursor);
164    let edit_history_tokens = estimate_tokens(edit_history_section.len());
165    let budget_after_edit_history = budget_after_cursor.saturating_sub(edit_history_tokens);
166
167    let related_files_section =
168        format_related_files_within_budget(&input.related_files, budget_after_edit_history);
169
170    let mut prompt = String::new();
171    prompt.push_str(&related_files_section);
172    prompt.push_str(&edit_history_section);
173    prompt.push_str(&cursor_section);
174    prompt
175}
176
177fn format_edit_history_within_budget(events: &[Arc<Event>], max_tokens: usize) -> String {
178    let header = "<|file_sep|>edit history\n";
179    let header_tokens = estimate_tokens(header.len());
180    if header_tokens >= max_tokens {
181        return String::new();
182    }
183
184    let mut event_strings: Vec<String> = Vec::new();
185    let mut total_tokens = header_tokens;
186
187    for event in events.iter().rev() {
188        let mut event_str = String::new();
189        write_event(&mut event_str, event);
190        let event_tokens = estimate_tokens(event_str.len());
191
192        if total_tokens + event_tokens > max_tokens {
193            break;
194        }
195        total_tokens += event_tokens;
196        event_strings.push(event_str);
197    }
198
199    if event_strings.is_empty() {
200        return String::new();
201    }
202
203    let mut result = String::from(header);
204    for event_str in event_strings.iter().rev() {
205        result.push_str(&event_str);
206    }
207    result
208}
209
210fn format_related_files_within_budget(related_files: &[RelatedFile], max_tokens: usize) -> String {
211    let mut result = String::new();
212    let mut total_tokens = 0;
213
214    for file in related_files {
215        let path_str = file.path.to_string_lossy();
216        let header_len = "<|file_sep|>".len() + path_str.len() + 1;
217        let header_tokens = estimate_tokens(header_len);
218
219        if total_tokens + header_tokens > max_tokens {
220            break;
221        }
222
223        let mut file_tokens = header_tokens;
224        let mut excerpts_to_include = 0;
225
226        for excerpt in &file.excerpts {
227            let needs_newline = !excerpt.text.ends_with('\n');
228            let needs_ellipsis = excerpt.row_range.end < file.max_row;
229            let excerpt_len = excerpt.text.len()
230                + if needs_newline { "\n".len() } else { "".len() }
231                + if needs_ellipsis {
232                    "...\n".len()
233                } else {
234                    "".len()
235                };
236
237            let excerpt_tokens = estimate_tokens(excerpt_len);
238            if total_tokens + file_tokens + excerpt_tokens > max_tokens {
239                break;
240            }
241            file_tokens += excerpt_tokens;
242            excerpts_to_include += 1;
243        }
244
245        if excerpts_to_include > 0 {
246            total_tokens += file_tokens;
247            write!(result, "<|file_sep|>{}\n", path_str).ok();
248            for excerpt in file.excerpts.iter().take(excerpts_to_include) {
249                result.push_str(&excerpt.text);
250                if !result.ends_with('\n') {
251                    result.push('\n');
252                }
253                if excerpt.row_range.end < file.max_row {
254                    result.push_str("...\n");
255                }
256            }
257        }
258    }
259
260    result
261}
262
263pub fn write_related_files(
264    prompt: &mut String,
265    related_files: &[RelatedFile],
266) -> Vec<Range<usize>> {
267    let mut ranges = Vec::new();
268    for file in related_files {
269        let start = prompt.len();
270        let path_str = file.path.to_string_lossy();
271        write!(prompt, "<|file_sep|>{}\n", path_str).ok();
272        for excerpt in &file.excerpts {
273            prompt.push_str(&excerpt.text);
274            if !prompt.ends_with('\n') {
275                prompt.push('\n');
276            }
277            if excerpt.row_range.end < file.max_row {
278                prompt.push_str("...\n");
279            }
280        }
281        let end = prompt.len();
282        ranges.push(start..end);
283    }
284    ranges
285}
286
287mod v0112_middle_at_end {
288    use super::*;
289
290    pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
291        let path_str = input.cursor_path.to_string_lossy();
292        write!(prompt, "<|file_sep|>{}\n", path_str).ok();
293
294        prompt.push_str("<|fim_prefix|>\n");
295        prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
296
297        prompt.push_str("<|fim_suffix|>\n");
298        prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
299        if !prompt.ends_with('\n') {
300            prompt.push('\n');
301        }
302
303        prompt.push_str("<|fim_middle|>current\n");
304        prompt.push_str(
305            &input.cursor_excerpt
306                [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
307        );
308        prompt.push_str(CURSOR_MARKER);
309        prompt.push_str(
310            &input.cursor_excerpt
311                [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
312        );
313        if !prompt.ends_with('\n') {
314            prompt.push('\n');
315        }
316
317        prompt.push_str("<|fim_middle|>updated\n");
318    }
319}
320
321mod v0113_ordered {
322    use super::*;
323
324    pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
325        let path_str = input.cursor_path.to_string_lossy();
326        write!(prompt, "<|file_sep|>{}\n", path_str).ok();
327
328        prompt.push_str("<|fim_prefix|>\n");
329        prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
330        if !prompt.ends_with('\n') {
331            prompt.push('\n');
332        }
333
334        prompt.push_str("<|fim_middle|>current\n");
335        prompt.push_str(
336            &input.cursor_excerpt
337                [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
338        );
339        prompt.push_str(CURSOR_MARKER);
340        prompt.push_str(
341            &input.cursor_excerpt
342                [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
343        );
344        if !prompt.ends_with('\n') {
345            prompt.push('\n');
346        }
347
348        prompt.push_str("<|fim_suffix|>\n");
349        prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
350        if !prompt.ends_with('\n') {
351            prompt.push('\n');
352        }
353
354        prompt.push_str("<|fim_middle|>updated\n");
355    }
356}
357
358pub mod v0120_git_merge_markers {
359    //! A prompt that uses git-style merge conflict markers to represent the editable region.
360    //!
361    //! Example prompt:
362    //!
363    //! <|file_sep|>path/to/target_file.py
364    //! <|fim_prefix|>
365    //! code before editable region
366    //! <|fim_suffix|>
367    //! code after editable region
368    //! <|fim_middle|>
369    //! <<<<<<< CURRENT
370    //! code that
371    //! needs to<|user_cursor|>
372    //! be rewritten
373    //! =======
374    //!
375    //! Expected output (should be generated by the model):
376    //!
377    //! updated
378    //! code with
379    //! changes applied
380    //! >>>>>>> UPDATED
381
382    use super::*;
383
384    pub const START_MARKER: &str = "<<<<<<< CURRENT\n";
385    pub const SEPARATOR: &str = "=======\n";
386    pub const END_MARKER: &str = ">>>>>>> UPDATED\n";
387
388    pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
389        let path_str = input.cursor_path.to_string_lossy();
390        write!(prompt, "<|file_sep|>{}\n", path_str).ok();
391
392        prompt.push_str("<|fim_prefix|>");
393        prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
394
395        prompt.push_str("<|fim_suffix|>");
396        prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
397        if !prompt.ends_with('\n') {
398            prompt.push('\n');
399        }
400
401        prompt.push_str("<|fim_middle|>");
402        prompt.push_str(START_MARKER);
403        prompt.push_str(
404            &input.cursor_excerpt
405                [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
406        );
407        prompt.push_str(CURSOR_MARKER);
408        prompt.push_str(
409            &input.cursor_excerpt
410                [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
411        );
412        if !prompt.ends_with('\n') {
413            prompt.push('\n');
414        }
415        prompt.push_str(SEPARATOR);
416    }
417}
418
/// The zeta1 prompt format
pub mod zeta1 {
    // Marker strings of the zeta1 format. `format_zeta1_prompt` does not insert them itself;
    // presumably callers embed them in the events/excerpt text — confirm at call sites.
    pub const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
    pub const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
    pub const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
    pub const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";

    const INSTRUCTION_HEADER: &str = concat!(
        "### Instruction:\n",
        "You are a code completion assistant and your task is to analyze user edits and then rewrite an ",
        "excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking ",
        "into account the cursor location.\n\n",
        "### User Edits:\n\n"
    );
    const EXCERPT_HEADER: &str = "\n\n### User Excerpt:\n\n";
    const RESPONSE_HEADER: &str = "\n\n### Response:\n";

    /// Builds a zeta1 prompt by concatenating these pieces, in order:
    /// - the instruction header;
    /// - `input_events`;
    /// - the excerpt header;
    /// - `input_excerpt`;
    /// - the response header.
    ///
    /// Both inputs are inserted verbatim.
    pub fn format_zeta1_prompt(input_events: &str, input_excerpt: &str) -> String {
        [
            INSTRUCTION_HEADER,
            input_events,
            EXCERPT_HEADER,
            input_excerpt,
            RESPONSE_HEADER,
        ]
        .concat()
    }
}
453
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;

    /// Builds a `ZetaPromptInput` for `test.rs` from the given excerpt, editable range,
    /// cursor offset, events (oldest first) and related files.
    fn make_input(
        cursor_excerpt: &str,
        editable_range: Range<usize>,
        cursor_offset: usize,
        events: Vec<Event>,
        related_files: Vec<RelatedFile>,
    ) -> ZetaPromptInput {
        ZetaPromptInput {
            cursor_path: Path::new("test.rs").into(),
            cursor_excerpt: cursor_excerpt.into(),
            editable_range_in_excerpt: editable_range,
            cursor_offset_in_excerpt: cursor_offset,
            events: events.into_iter().map(Arc::new).collect(),
            related_files,
        }
    }

    /// Builds a non-predicted `BufferChange` whose old and new paths are both `path`.
    fn make_event(path: &str, diff: &str) -> Event {
        Event::BufferChange {
            path: Path::new(path).into(),
            old_path: Path::new(path).into(),
            diff: diff.to_string(),
            predicted: false,
            in_open_source_repo: false,
        }
    }

    /// Builds a related file with one excerpt covering all of `content`. Because the
    /// excerpt's `row_range.end` equals `max_row`, no `...` line is emitted after it.
    fn make_related_file(path: &str, content: &str) -> RelatedFile {
        RelatedFile {
            path: Path::new(path).into(),
            max_row: content.lines().count() as u32,
            excerpts: vec![RelatedExcerpt {
                row_range: 0..content.lines().count() as u32,
                text: content.into(),
            }],
        }
    }

    /// Formats with the default `V0114180EditableRegion` layout and an explicit token budget.
    fn format_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> String {
        format_zeta_prompt_with_budget(input, ZetaVersion::V0114180EditableRegion, max_tokens)
    }

    // With a generous budget every section is present, in the order:
    // related files, edit history, cursor section.
    #[test]
    fn test_no_truncation_when_within_budget() {
        let input = make_input(
            "prefix\neditable\nsuffix",
            7..15,
            10,
            vec![make_event("a.rs", "-old\n+new\n")],
            vec![make_related_file("related.rs", "fn helper() {}\n")],
        );

        assert_eq!(
            format_with_budget(&input, 10000),
            indoc! {r#"
                <|file_sep|>related.rs
                fn helper() {}
                <|file_sep|>edit history
                --- a/a.rs
                +++ b/a.rs
                -old
                +new
                <|file_sep|>test.rs
                <|fim_prefix|>
                prefix
                <|fim_middle|>current
                edi<|user_cursor|>table
                <|fim_suffix|>

                suffix
                <|fim_middle|>updated
            "#}
        );
    }

    // Edit history is budgeted before related files. A 50-token budget is too small for the
    // edit-history event, yet both tiny related files still fit.
    #[test]
    fn test_truncation_drops_edit_history_when_budget_tight() {
        let input = make_input(
            "code",
            0..4,
            2,
            vec![make_event("a.rs", "-x\n+y\n")],
            vec![
                make_related_file("r1.rs", "a\n"),
                make_related_file("r2.rs", "b\n"),
            ],
        );

        assert_eq!(
            format_with_budget(&input, 10000),
            indoc! {r#"
                <|file_sep|>r1.rs
                a
                <|file_sep|>r2.rs
                b
                <|file_sep|>edit history
                --- a/a.rs
                +++ b/a.rs
                -x
                +y
                <|file_sep|>test.rs
                <|fim_prefix|>
                <|fim_middle|>current
                co<|user_cursor|>de
                <|fim_suffix|>
                <|fim_middle|>updated
            "#}
        );

        assert_eq!(
            format_with_budget(&input, 50),
            indoc! {r#"
                <|file_sep|>r1.rs
                a
                <|file_sep|>r2.rs
                b
                <|file_sep|>test.rs
                <|fim_prefix|>
                <|fim_middle|>current
                co<|user_cursor|>de
                <|fim_suffix|>
                <|fim_middle|>updated
            "#}
        );
    }

    // Excerpts ending before `max_row` are followed by `...`. Under a tight budget only a
    // leading prefix of a file's excerpts is kept.
    #[test]
    fn test_truncation_includes_partial_excerpts() {
        let input = make_input(
            "x",
            0..1,
            0,
            vec![],
            vec![RelatedFile {
                path: Path::new("big.rs").into(),
                max_row: 30,
                excerpts: vec![
                    RelatedExcerpt {
                        row_range: 0..10,
                        text: "first excerpt\n".into(),
                    },
                    RelatedExcerpt {
                        row_range: 10..20,
                        text: "second excerpt\n".into(),
                    },
                    RelatedExcerpt {
                        row_range: 20..30,
                        text: "third excerpt\n".into(),
                    },
                ],
            }],
        );

        assert_eq!(
            format_with_budget(&input, 10000),
            indoc! {r#"
                <|file_sep|>big.rs
                first excerpt
                ...
                second excerpt
                ...
                third excerpt
                <|file_sep|>test.rs
                <|fim_prefix|>
                <|fim_middle|>current
                <|user_cursor|>x
                <|fim_suffix|>
                <|fim_middle|>updated
            "#}
        );

        assert_eq!(
            format_with_budget(&input, 50),
            indoc! {r#"
                <|file_sep|>big.rs
                first excerpt
                ...
                <|file_sep|>test.rs
                <|fim_prefix|>
                <|fim_middle|>current
                <|user_cursor|>x
                <|fim_suffix|>
                <|fim_middle|>updated
            "#}
        );
    }

    // Events are given oldest first; when only one fits, the newest one is kept.
    #[test]
    fn test_truncation_drops_older_events_first() {
        let input = make_input(
            "x",
            0..1,
            0,
            vec![make_event("old.rs", "-1\n"), make_event("new.rs", "-2\n")],
            vec![],
        );

        assert_eq!(
            format_with_budget(&input, 10000),
            indoc! {r#"
                <|file_sep|>edit history
                --- a/old.rs
                +++ b/old.rs
                -1
                --- a/new.rs
                +++ b/new.rs
                -2
                <|file_sep|>test.rs
                <|fim_prefix|>
                <|fim_middle|>current
                <|user_cursor|>x
                <|fim_suffix|>
                <|fim_middle|>updated
            "#}
        );

        assert_eq!(
            format_with_budget(&input, 55),
            indoc! {r#"
                <|file_sep|>edit history
                --- a/new.rs
                +++ b/new.rs
                -2
                <|file_sep|>test.rs
                <|fim_prefix|>
                <|fim_middle|>current
                <|user_cursor|>x
                <|fim_suffix|>
                <|fim_middle|>updated
            "#}
        );
    }

    // The cursor section alone exceeds a 30-token budget, so it is the only section emitted;
    // it is never truncated.
    #[test]
    fn test_cursor_excerpt_always_included_with_minimal_budget() {
        let input = make_input(
            "fn main() {}",
            0..12,
            3,
            vec![make_event("a.rs", "-old\n+new\n")],
            vec![make_related_file("related.rs", "helper\n")],
        );

        assert_eq!(
            format_with_budget(&input, 30),
            indoc! {r#"
                <|file_sep|>test.rs
                <|fim_prefix|>
                <|fim_middle|>current
                fn <|user_cursor|>main() {}
                <|fim_suffix|>
                <|fim_middle|>updated
            "#}
        );
    }
}