zeta_prompt.rs

  1use anyhow::Result;
  2use serde::{Deserialize, Serialize};
  3use std::fmt::Write;
  4use std::ops::Range;
  5use std::path::Path;
  6use std::sync::Arc;
  7use strum::{EnumIter, IntoEnumIterator as _, IntoStaticStr};
  8
  9pub const CURSOR_MARKER: &str = "<|user_cursor|>";
 10pub const MAX_PROMPT_TOKENS: usize = 4096;
 11
 12fn estimate_tokens(bytes: usize) -> usize {
 13    bytes / 3
 14}
 15
 16#[derive(Clone, Debug, Serialize, Deserialize)]
 17pub struct ZetaPromptInput {
 18    pub cursor_path: Arc<Path>,
 19    pub cursor_excerpt: Arc<str>,
 20    pub editable_range_in_excerpt: Range<usize>,
 21    pub cursor_offset_in_excerpt: usize,
 22    pub events: Vec<Arc<Event>>,
 23    pub related_files: Vec<RelatedFile>,
 24}
 25
 26#[derive(
 27    Default,
 28    Clone,
 29    Copy,
 30    Debug,
 31    PartialEq,
 32    Eq,
 33    Hash,
 34    EnumIter,
 35    IntoStaticStr,
 36    Serialize,
 37    Deserialize,
 38)]
 39#[allow(non_camel_case_types)]
 40pub enum ZetaVersion {
 41    V0112MiddleAtEnd,
 42    V0113Ordered,
 43    #[default]
 44    V0114180EditableRegion,
 45    V0120GitMergeMarkers,
 46}
 47
 48impl std::fmt::Display for ZetaVersion {
 49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 50        write!(f, "{}", <&'static str>::from(self))
 51    }
 52}
 53
 54impl ZetaVersion {
 55    pub fn parse(version_string: &str) -> Result<Self> {
 56        let mut results = ZetaVersion::iter().filter(|version| {
 57            <&'static str>::from(version)
 58                .to_lowercase()
 59                .contains(&version_string.to_lowercase())
 60        });
 61        let Some(result) = results.next() else {
 62            anyhow::bail!(
 63                "`{version_string}` did not match any of:\n{}",
 64                Self::options_as_string()
 65            );
 66        };
 67        if results.next().is_some() {
 68            anyhow::bail!(
 69                "`{version_string}` matched more than one of:\n{}",
 70                Self::options_as_string()
 71            );
 72        }
 73        Ok(result)
 74    }
 75
 76    pub fn options_as_string() -> String {
 77        ZetaVersion::iter()
 78            .map(|version| format!("- {}\n", <&'static str>::from(version)))
 79            .collect::<Vec<_>>()
 80            .concat()
 81    }
 82}
 83
 84#[derive(Clone, Debug, Serialize, Deserialize)]
 85#[serde(tag = "event")]
 86pub enum Event {
 87    BufferChange {
 88        path: Arc<Path>,
 89        old_path: Arc<Path>,
 90        diff: String,
 91        predicted: bool,
 92        in_open_source_repo: bool,
 93    },
 94}
 95
 96pub fn write_event(prompt: &mut String, event: &Event) {
 97    fn write_path_as_unix_str(prompt: &mut String, path: &Path) {
 98        for component in path.components() {
 99            prompt.push('/');
100            write!(prompt, "{}", component.as_os_str().display()).ok();
101        }
102    }
103    match event {
104        Event::BufferChange {
105            path,
106            old_path,
107            diff,
108            predicted,
109            in_open_source_repo: _,
110        } => {
111            if *predicted {
112                prompt.push_str("// User accepted prediction:\n");
113            }
114            prompt.push_str("--- a");
115            write_path_as_unix_str(prompt, old_path.as_ref());
116            prompt.push_str("\n+++ b");
117            write_path_as_unix_str(prompt, path.as_ref());
118            prompt.push('\n');
119            prompt.push_str(diff);
120        }
121    }
122}
123
124#[derive(Clone, Debug, Serialize, Deserialize)]
125pub struct RelatedFile {
126    pub path: Arc<Path>,
127    pub max_row: u32,
128    pub excerpts: Vec<RelatedExcerpt>,
129}
130
131#[derive(Clone, Debug, Serialize, Deserialize)]
132pub struct RelatedExcerpt {
133    pub row_range: Range<u32>,
134    pub text: Arc<str>,
135}
136
137pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> String {
138    format_zeta_prompt_with_budget(input, version, MAX_PROMPT_TOKENS)
139}
140
141fn format_zeta_prompt_with_budget(
142    input: &ZetaPromptInput,
143    version: ZetaVersion,
144    max_tokens: usize,
145) -> String {
146    let mut cursor_section = String::new();
147    match version {
148        ZetaVersion::V0112MiddleAtEnd => {
149            v0112_middle_at_end::write_cursor_excerpt_section(&mut cursor_section, input);
150        }
151        ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => {
152            v0113_ordered::write_cursor_excerpt_section(&mut cursor_section, input)
153        }
154        ZetaVersion::V0120GitMergeMarkers => {
155            v0120_git_merge_markers::write_cursor_excerpt_section(&mut cursor_section, input)
156        }
157    }
158
159    let cursor_tokens = estimate_tokens(cursor_section.len());
160    let budget_after_cursor = max_tokens.saturating_sub(cursor_tokens);
161
162    let edit_history_section =
163        format_edit_history_within_budget(&input.events, budget_after_cursor);
164    let edit_history_tokens = estimate_tokens(edit_history_section.len());
165    let budget_after_edit_history = budget_after_cursor.saturating_sub(edit_history_tokens);
166
167    let related_files_section =
168        format_related_files_within_budget(&input.related_files, budget_after_edit_history);
169
170    let mut prompt = String::new();
171    prompt.push_str(&related_files_section);
172    prompt.push_str(&edit_history_section);
173    prompt.push_str(&cursor_section);
174    prompt
175}
176
177fn format_edit_history_within_budget(events: &[Arc<Event>], max_tokens: usize) -> String {
178    let header = "<|file_sep|>edit history\n";
179    let header_tokens = estimate_tokens(header.len());
180    if header_tokens >= max_tokens {
181        return String::new();
182    }
183
184    let mut event_strings: Vec<String> = Vec::new();
185    let mut total_tokens = header_tokens;
186
187    for event in events.iter().rev() {
188        let mut event_str = String::new();
189        write_event(&mut event_str, event);
190        let event_tokens = estimate_tokens(event_str.len());
191
192        if total_tokens + event_tokens > max_tokens {
193            break;
194        }
195        total_tokens += event_tokens;
196        event_strings.push(event_str);
197    }
198
199    if event_strings.is_empty() {
200        return String::new();
201    }
202
203    let mut result = String::from(header);
204    for event_str in event_strings.iter().rev() {
205        result.push_str(&event_str);
206    }
207    result
208}
209
210fn format_related_files_within_budget(related_files: &[RelatedFile], max_tokens: usize) -> String {
211    let mut result = String::new();
212    let mut total_tokens = 0;
213
214    for file in related_files {
215        let path_str = file.path.to_string_lossy();
216        let header_len = "<|file_sep|>".len() + path_str.len() + 1;
217        let header_tokens = estimate_tokens(header_len);
218
219        if total_tokens + header_tokens > max_tokens {
220            break;
221        }
222
223        let mut file_tokens = header_tokens;
224        let mut excerpts_to_include = 0;
225
226        for excerpt in &file.excerpts {
227            let needs_newline = !excerpt.text.ends_with('\n');
228            let needs_ellipsis = excerpt.row_range.end < file.max_row;
229            let excerpt_len = excerpt.text.len()
230                + if needs_newline { "\n".len() } else { "".len() }
231                + if needs_ellipsis {
232                    "...\n".len()
233                } else {
234                    "".len()
235                };
236
237            let excerpt_tokens = estimate_tokens(excerpt_len);
238            if total_tokens + file_tokens + excerpt_tokens > max_tokens {
239                break;
240            }
241            file_tokens += excerpt_tokens;
242            excerpts_to_include += 1;
243        }
244
245        if excerpts_to_include > 0 {
246            total_tokens += file_tokens;
247            write!(result, "<|file_sep|>{}\n", path_str).ok();
248            for excerpt in file.excerpts.iter().take(excerpts_to_include) {
249                result.push_str(&excerpt.text);
250                if !result.ends_with('\n') {
251                    result.push('\n');
252                }
253                if excerpt.row_range.end < file.max_row {
254                    result.push_str("...\n");
255                }
256            }
257        }
258    }
259
260    result
261}
262
263pub fn write_related_files(
264    prompt: &mut String,
265    related_files: &[RelatedFile],
266) -> Vec<Range<usize>> {
267    let mut ranges = Vec::new();
268    for file in related_files {
269        let start = prompt.len();
270        let path_str = file.path.to_string_lossy();
271        write!(prompt, "<|file_sep|>{}\n", path_str).ok();
272        for excerpt in &file.excerpts {
273            prompt.push_str(&excerpt.text);
274            if !prompt.ends_with('\n') {
275                prompt.push('\n');
276            }
277            if excerpt.row_range.end < file.max_row {
278                prompt.push_str("...\n");
279            }
280        }
281        let end = prompt.len();
282        ranges.push(start..end);
283    }
284    ranges
285}
286
287mod v0112_middle_at_end {
288    use super::*;
289
290    pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
291        let path_str = input.cursor_path.to_string_lossy();
292        write!(prompt, "<|file_sep|>{}\n", path_str).ok();
293
294        prompt.push_str("<|fim_prefix|>\n");
295        prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
296
297        prompt.push_str("<|fim_suffix|>\n");
298        prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
299        if !prompt.ends_with('\n') {
300            prompt.push('\n');
301        }
302
303        prompt.push_str("<|fim_middle|>current\n");
304        prompt.push_str(
305            &input.cursor_excerpt
306                [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
307        );
308        prompt.push_str(CURSOR_MARKER);
309        prompt.push_str(
310            &input.cursor_excerpt
311                [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
312        );
313        if !prompt.ends_with('\n') {
314            prompt.push('\n');
315        }
316
317        prompt.push_str("<|fim_middle|>updated\n");
318    }
319}
320
321mod v0113_ordered {
322    use super::*;
323
324    pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
325        let path_str = input.cursor_path.to_string_lossy();
326        write!(prompt, "<|file_sep|>{}\n", path_str).ok();
327
328        prompt.push_str("<|fim_prefix|>\n");
329        prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
330        if !prompt.ends_with('\n') {
331            prompt.push('\n');
332        }
333
334        prompt.push_str("<|fim_middle|>current\n");
335        prompt.push_str(
336            &input.cursor_excerpt
337                [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
338        );
339        prompt.push_str(CURSOR_MARKER);
340        prompt.push_str(
341            &input.cursor_excerpt
342                [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
343        );
344        if !prompt.ends_with('\n') {
345            prompt.push('\n');
346        }
347
348        prompt.push_str("<|fim_suffix|>\n");
349        prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
350        if !prompt.ends_with('\n') {
351            prompt.push('\n');
352        }
353
354        prompt.push_str("<|fim_middle|>updated\n");
355    }
356}
357
358pub mod v0120_git_merge_markers {
359    //! A prompt that uses git-style merge conflict markers to represent the editable region.
360    //!
361    //! Example prompt:
362    //!
363    //! <|file_sep|>path/to/target_file.py
364    //! <|fim_prefix|>
365    //! code before editable region
366    //! <|fim_suffix|>
367    //! code after editable region
368    //! <|fim_middle|>
369    //! <<<<<<< CURRENT
370    //! code that
371    //! needs to<|user_cursor|>
372    //! be rewritten
373    //! =======
374    //!
375    //! Expected output (should be generated by the model):
376    //!
377    //! updated
378    //! code with
379    //! changes applied
380    //! >>>>>>> UPDATED
381
382    use super::*;
383
384    pub const START_MARKER: &str = "<<<<<<< CURRENT\n";
385    pub const SEPARATOR: &str = "=======\n";
386    pub const END_MARKER: &str = ">>>>>>> UPDATED\n";
387
388    pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
389        let path_str = input.cursor_path.to_string_lossy();
390        write!(prompt, "<|file_sep|>{}\n", path_str).ok();
391
392        prompt.push_str("<|fim_prefix|>");
393        prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
394
395        prompt.push_str("<|fim_suffix|>");
396        prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
397        if !prompt.ends_with('\n') {
398            prompt.push('\n');
399        }
400
401        prompt.push_str("<|fim_middle|>");
402        prompt.push_str(START_MARKER);
403        prompt.push_str(
404            &input.cursor_excerpt
405                [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
406        );
407        prompt.push_str(CURSOR_MARKER);
408        prompt.push_str(
409            &input.cursor_excerpt
410                [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
411        );
412        if !prompt.ends_with('\n') {
413            prompt.push('\n');
414        }
415        prompt.push_str(SEPARATOR);
416    }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;

    /// Builds a prompt input for the file `test.rs`.
    fn make_input(
        cursor_excerpt: &str,
        editable_range: Range<usize>,
        cursor_offset: usize,
        events: Vec<Event>,
        related_files: Vec<RelatedFile>,
    ) -> ZetaPromptInput {
        let events = events.into_iter().map(Arc::new).collect();
        ZetaPromptInput {
            cursor_path: Arc::from(Path::new("test.rs")),
            cursor_excerpt: Arc::from(cursor_excerpt),
            editable_range_in_excerpt: editable_range,
            cursor_offset_in_excerpt: cursor_offset,
            events,
            related_files,
        }
    }

    /// Builds a non-predicted buffer change whose old and new paths are both `path`.
    fn make_event(path: &str, diff: &str) -> Event {
        let path: Arc<Path> = Arc::from(Path::new(path));
        Event::BufferChange {
            path: Arc::clone(&path),
            old_path: path,
            diff: diff.to_string(),
            predicted: false,
            in_open_source_repo: false,
        }
    }

    /// Builds a related file with one excerpt that covers all of `content`,
    /// so no `...` line follows it.
    fn make_related_file(path: &str, content: &str) -> RelatedFile {
        let row_count = content.lines().count() as u32;
        RelatedFile {
            path: Arc::from(Path::new(path)),
            max_row: row_count,
            excerpts: vec![RelatedExcerpt {
                row_range: 0..row_count,
                text: Arc::from(content),
            }],
        }
    }

    /// Formats `input` with the `V0114180EditableRegion` layout and the given budget.
    fn format_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> String {
        format_zeta_prompt_with_budget(input, ZetaVersion::V0114180EditableRegion, max_tokens)
    }

    #[test]
    fn test_no_truncation_when_within_budget() {
        let input = make_input(
            "prefix\neditable\nsuffix",
            7..15,
            10,
            vec![make_event("a.rs", "-old\n+new\n")],
            vec![make_related_file("related.rs", "fn helper() {}\n")],
        );

        let expected = indoc! {r#"
            <|file_sep|>related.rs
            fn helper() {}
            <|file_sep|>edit history
            --- a/a.rs
            +++ b/a.rs
            -old
            +new
            <|file_sep|>test.rs
            <|fim_prefix|>
            prefix
            <|fim_middle|>current
            edi<|user_cursor|>table
            <|fim_suffix|>

            suffix
            <|fim_middle|>updated
        "#};
        assert_eq!(format_with_budget(&input, 10000), expected);
    }

    #[test]
    fn test_truncation_drops_edit_history_when_budget_tight() {
        let input = make_input(
            "code",
            0..4,
            2,
            vec![make_event("a.rs", "-x\n+y\n")],
            vec![
                make_related_file("r1.rs", "a\n"),
                make_related_file("r2.rs", "b\n"),
            ],
        );

        // Generous budget: every section is present.
        let expected_full = indoc! {r#"
            <|file_sep|>r1.rs
            a
            <|file_sep|>r2.rs
            b
            <|file_sep|>edit history
            --- a/a.rs
            +++ b/a.rs
            -x
            +y
            <|file_sep|>test.rs
            <|fim_prefix|>
            <|fim_middle|>current
            co<|user_cursor|>de
            <|fim_suffix|>
            <|fim_middle|>updated
        "#};
        assert_eq!(format_with_budget(&input, 10000), expected_full);

        // Tight budget: the edit history no longer fits, the related files still do.
        let expected_tight = indoc! {r#"
            <|file_sep|>r1.rs
            a
            <|file_sep|>r2.rs
            b
            <|file_sep|>test.rs
            <|fim_prefix|>
            <|fim_middle|>current
            co<|user_cursor|>de
            <|fim_suffix|>
            <|fim_middle|>updated
        "#};
        assert_eq!(format_with_budget(&input, 50), expected_tight);
    }

    #[test]
    fn test_truncation_includes_partial_excerpts() {
        let big_file = RelatedFile {
            path: Arc::from(Path::new("big.rs")),
            max_row: 30,
            excerpts: vec![
                RelatedExcerpt {
                    row_range: 0..10,
                    text: "first excerpt\n".into(),
                },
                RelatedExcerpt {
                    row_range: 10..20,
                    text: "second excerpt\n".into(),
                },
                RelatedExcerpt {
                    row_range: 20..30,
                    text: "third excerpt\n".into(),
                },
            ],
        };
        let input = make_input("x", 0..1, 0, vec![], vec![big_file]);

        // Generous budget: all three excerpts are present.
        let expected_full = indoc! {r#"
            <|file_sep|>big.rs
            first excerpt
            ...
            second excerpt
            ...
            third excerpt
            <|file_sep|>test.rs
            <|fim_prefix|>
            <|fim_middle|>current
            <|user_cursor|>x
            <|fim_suffix|>
            <|fim_middle|>updated
        "#};
        assert_eq!(format_with_budget(&input, 10000), expected_full);

        // Tight budget: only the first excerpt fits.
        let expected_tight = indoc! {r#"
            <|file_sep|>big.rs
            first excerpt
            ...
            <|file_sep|>test.rs
            <|fim_prefix|>
            <|fim_middle|>current
            <|user_cursor|>x
            <|fim_suffix|>
            <|fim_middle|>updated
        "#};
        assert_eq!(format_with_budget(&input, 50), expected_tight);
    }

    #[test]
    fn test_truncation_drops_older_events_first() {
        let input = make_input(
            "x",
            0..1,
            0,
            vec![make_event("old.rs", "-1\n"), make_event("new.rs", "-2\n")],
            vec![],
        );

        // Generous budget: both events are present, oldest first.
        let expected_full = indoc! {r#"
            <|file_sep|>edit history
            --- a/old.rs
            +++ b/old.rs
            -1
            --- a/new.rs
            +++ b/new.rs
            -2
            <|file_sep|>test.rs
            <|fim_prefix|>
            <|fim_middle|>current
            <|user_cursor|>x
            <|fim_suffix|>
            <|fim_middle|>updated
        "#};
        assert_eq!(format_with_budget(&input, 10000), expected_full);

        // Tight budget: only the most recent event survives.
        let expected_tight = indoc! {r#"
            <|file_sep|>edit history
            --- a/new.rs
            +++ b/new.rs
            -2
            <|file_sep|>test.rs
            <|fim_prefix|>
            <|fim_middle|>current
            <|user_cursor|>x
            <|fim_suffix|>
            <|fim_middle|>updated
        "#};
        assert_eq!(format_with_budget(&input, 55), expected_tight);
    }

    #[test]
    fn test_cursor_excerpt_always_included_with_minimal_budget() {
        let input = make_input(
            "fn main() {}",
            0..12,
            3,
            vec![make_event("a.rs", "-old\n+new\n")],
            vec![make_related_file("related.rs", "helper\n")],
        );

        // The budget is too small for any optional context; the cursor section remains.
        let expected = indoc! {r#"
            <|file_sep|>test.rs
            <|fim_prefix|>
            <|fim_middle|>current
            fn <|user_cursor|>main() {}
            <|fim_suffix|>
            <|fim_middle|>updated
        "#};
        assert_eq!(format_with_budget(&input, 30), expected);
    }
}