zeta_prompt.rs

  1use anyhow::Result;
  2use serde::{Deserialize, Serialize};
  3use std::fmt::Write;
  4use std::ops::Range;
  5use std::path::Path;
  6use std::sync::Arc;
  7use strum::{EnumIter, IntoEnumIterator as _, IntoStaticStr};
  8
  9pub const CURSOR_MARKER: &str = "<|user_cursor|>";
 10
 11#[derive(Clone, Debug, Serialize, Deserialize)]
 12pub struct ZetaPromptInput {
 13    pub cursor_path: Arc<Path>,
 14    pub cursor_excerpt: Arc<str>,
 15    pub editable_range_in_excerpt: Range<usize>,
 16    pub cursor_offset_in_excerpt: usize,
 17    pub events: Vec<Arc<Event>>,
 18    pub related_files: Vec<RelatedFile>,
 19}
 20
 21#[derive(
 22    Default,
 23    Clone,
 24    Copy,
 25    Debug,
 26    PartialEq,
 27    Eq,
 28    Hash,
 29    EnumIter,
 30    IntoStaticStr,
 31    Serialize,
 32    Deserialize,
 33)]
 34#[allow(non_camel_case_types)]
 35pub enum ZetaVersion {
 36    V0112MiddleAtEnd,
 37    V0113Ordered,
 38    #[default]
 39    V0114180EditableRegion,
 40    V0120GitMergeMarkers,
 41}
 42
 43impl std::fmt::Display for ZetaVersion {
 44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 45        write!(f, "{}", <&'static str>::from(self))
 46    }
 47}
 48
 49impl ZetaVersion {
 50    pub fn parse(version_string: &str) -> Result<Self> {
 51        let mut results = ZetaVersion::iter().filter(|version| {
 52            <&'static str>::from(version)
 53                .to_lowercase()
 54                .contains(&version_string.to_lowercase())
 55        });
 56        let Some(result) = results.next() else {
 57            anyhow::bail!(
 58                "`{version_string}` did not match any of:\n{}",
 59                Self::options_as_string()
 60            );
 61        };
 62        if results.next().is_some() {
 63            anyhow::bail!(
 64                "`{version_string}` matched more than one of:\n{}",
 65                Self::options_as_string()
 66            );
 67        }
 68        Ok(result)
 69    }
 70
 71    pub fn options_as_string() -> String {
 72        ZetaVersion::iter()
 73            .map(|version| format!("- {}\n", <&'static str>::from(version)))
 74            .collect::<Vec<_>>()
 75            .concat()
 76    }
 77}
 78
 79#[derive(Clone, Debug, Serialize, Deserialize)]
 80#[serde(tag = "event")]
 81pub enum Event {
 82    BufferChange {
 83        path: Arc<Path>,
 84        old_path: Arc<Path>,
 85        diff: String,
 86        predicted: bool,
 87        in_open_source_repo: bool,
 88    },
 89}
 90
 91pub fn write_event(prompt: &mut String, event: &Event) {
 92    fn write_path_as_unix_str(prompt: &mut String, path: &Path) {
 93        for component in path.components() {
 94            prompt.push('/');
 95            write!(prompt, "{}", component.as_os_str().display()).ok();
 96        }
 97    }
 98    match event {
 99        Event::BufferChange {
100            path,
101            old_path,
102            diff,
103            predicted,
104            in_open_source_repo: _,
105        } => {
106            if *predicted {
107                prompt.push_str("// User accepted prediction:\n");
108            }
109            prompt.push_str("--- a");
110            write_path_as_unix_str(prompt, old_path.as_ref());
111            prompt.push_str("\n+++ b");
112            write_path_as_unix_str(prompt, path.as_ref());
113            prompt.push('\n');
114            prompt.push_str(diff);
115        }
116    }
117}
118
119#[derive(Clone, Debug, Serialize, Deserialize)]
120pub struct RelatedFile {
121    pub path: Arc<Path>,
122    pub max_row: u32,
123    pub excerpts: Vec<RelatedExcerpt>,
124}
125
126#[derive(Clone, Debug, Serialize, Deserialize)]
127pub struct RelatedExcerpt {
128    pub row_range: Range<u32>,
129    pub text: Arc<str>,
130}
131
132pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> String {
133    let mut prompt = String::new();
134    write_related_files(&mut prompt, &input.related_files);
135    write_edit_history_section(&mut prompt, input);
136
137    match version {
138        ZetaVersion::V0112MiddleAtEnd => {
139            v0112_middle_at_end::write_cursor_excerpt_section(&mut prompt, input);
140        }
141        ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => {
142            v0113_ordered::write_cursor_excerpt_section(&mut prompt, input)
143        }
144
145        ZetaVersion::V0120GitMergeMarkers => {
146            v0120_git_merge_markers::write_cursor_excerpt_section(&mut prompt, input)
147        }
148    }
149
150    prompt
151}
152
153pub fn write_related_files(prompt: &mut String, related_files: &[RelatedFile]) {
154    for file in related_files {
155        let path_str = file.path.to_string_lossy();
156        write!(prompt, "<|file_sep|>{}\n", path_str).ok();
157        for excerpt in &file.excerpts {
158            prompt.push_str(&excerpt.text);
159            if !prompt.ends_with('\n') {
160                prompt.push('\n');
161            }
162            if excerpt.row_range.end < file.max_row {
163                prompt.push_str("...\n");
164            }
165        }
166    }
167}
168
169fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) {
170    prompt.push_str("<|file_sep|>edit history\n");
171    for event in &input.events {
172        write_event(prompt, event);
173    }
174}
175
176mod v0112_middle_at_end {
177    use super::*;
178
179    pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
180        let path_str = input.cursor_path.to_string_lossy();
181        write!(prompt, "<|file_sep|>{}\n", path_str).ok();
182
183        prompt.push_str("<|fim_prefix|>\n");
184        prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
185
186        prompt.push_str("<|fim_suffix|>\n");
187        prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
188        if !prompt.ends_with('\n') {
189            prompt.push('\n');
190        }
191
192        prompt.push_str("<|fim_middle|>current\n");
193        prompt.push_str(
194            &input.cursor_excerpt
195                [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
196        );
197        prompt.push_str(CURSOR_MARKER);
198        prompt.push_str(
199            &input.cursor_excerpt
200                [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
201        );
202        if !prompt.ends_with('\n') {
203            prompt.push('\n');
204        }
205
206        prompt.push_str("<|fim_middle|>updated\n");
207    }
208}
209
210mod v0113_ordered {
211    use super::*;
212
213    pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
214        let path_str = input.cursor_path.to_string_lossy();
215        write!(prompt, "<|file_sep|>{}\n", path_str).ok();
216
217        prompt.push_str("<|fim_prefix|>\n");
218        prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
219        if !prompt.ends_with('\n') {
220            prompt.push('\n');
221        }
222
223        prompt.push_str("<|fim_middle|>current\n");
224        prompt.push_str(
225            &input.cursor_excerpt
226                [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
227        );
228        prompt.push_str(CURSOR_MARKER);
229        prompt.push_str(
230            &input.cursor_excerpt
231                [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
232        );
233        if !prompt.ends_with('\n') {
234            prompt.push('\n');
235        }
236
237        prompt.push_str("<|fim_suffix|>\n");
238        prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
239        if !prompt.ends_with('\n') {
240            prompt.push('\n');
241        }
242
243        prompt.push_str("<|fim_middle|>updated\n");
244    }
245}
246
247pub mod v0120_git_merge_markers {
248    //! A prompt that uses git-style merge conflict markers to represent the editable region.
249    //!
250    //! Example prompt:
251    //!
252    //! <|file_sep|>path/to/target_file.py
253    //! <|fim_prefix|>
254    //! code before editable region
255    //! <|fim_suffix|>
256    //! code after editable region
257    //! <|fim_middle|>
258    //! <<<<<<< CURRENT
259    //! code that
260    //! needs to<|user_cursor|>
261    //! be rewritten
262    //! =======
263    //!
264    //! Expected output (should be generated by the model):
265    //!
266    //! updated
267    //! code with
268    //! changes applied
269    //! >>>>>>> UPDATED
270
271    use super::*;
272
273    pub const START_MARKER: &str = "<<<<<<< CURRENT\n";
274    pub const SEPARATOR: &str = "=======\n";
275    pub const END_MARKER: &str = ">>>>>>> UPDATED\n";
276
277    pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
278        let path_str = input.cursor_path.to_string_lossy();
279        write!(prompt, "<|file_sep|>{}\n", path_str).ok();
280
281        prompt.push_str("<|fim_prefix|>");
282        prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
283
284        prompt.push_str("<|fim_suffix|>");
285        prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
286        if !prompt.ends_with('\n') {
287            prompt.push('\n');
288        }
289
290        prompt.push_str("<|fim_middle|>");
291        prompt.push_str(START_MARKER);
292        prompt.push_str(
293            &input.cursor_excerpt
294                [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
295        );
296        prompt.push_str(CURSOR_MARKER);
297        prompt.push_str(
298            &input.cursor_excerpt
299                [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
300        );
301        if !prompt.ends_with('\n') {
302            prompt.push('\n');
303        }
304        prompt.push_str(SEPARATOR);
305    }
306}