zeta_prompt.rs

  1use anyhow::Result;
  2use serde::{Deserialize, Serialize};
  3use std::fmt::Write;
  4use std::ops::Range;
  5use std::path::Path;
  6use std::sync::Arc;
  7use strum::{EnumIter, IntoEnumIterator as _, IntoStaticStr};
  8
/// Marker inserted into the editable region of the cursor excerpt at the
/// user's cursor position.
pub const CURSOR_MARKER: &str = "<|user_cursor|>";

/// Everything needed to build a Zeta edit-prediction prompt.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ZetaPromptInput {
    /// Path of the buffer containing the cursor, written on the cursor
    /// section's `<|file_sep|>` line.
    pub cursor_path: Arc<Path>,
    /// Text surrounding the cursor; the offsets below are byte offsets into
    /// this string.
    pub cursor_excerpt: Arc<str>,
    /// Byte range within `cursor_excerpt` that is presented as the editable
    /// ("middle") region.
    pub editable_range_in_excerpt: Range<usize>,
    /// Byte offset of the cursor within `cursor_excerpt`. The prompt writers
    /// slice `editable_range.start..cursor` and `cursor..editable_range.end`,
    /// so this must lie inside the editable range on a `char` boundary or
    /// formatting panics.
    pub cursor_offset_in_excerpt: usize,
    /// Edit history, written in order into the `edit history` section.
    pub events: Vec<Arc<Event>>,
    /// Context files, written before the edit history.
    pub related_files: Vec<RelatedFile>,
}
 20
/// Prompt layout versions understood by [`format_zeta_prompt`].
///
/// The variant names double as the strings shown and accepted by
/// [`ZetaVersion::parse`].
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, EnumIter, IntoStaticStr)]
// Variant names combine a numeric version id with a descriptive suffix joined
// by an underscore, which is not UpperCamelCase.
#[allow(non_camel_case_types)]
pub enum ZetaVersion {
    /// Prefix, then suffix, then the editable region ("middle") last.
    V0112_MiddleAtEnd,
    /// Prefix, editable region, and suffix in document order.
    #[default]
    V0113_Ordered,
}
 28
 29impl std::fmt::Display for ZetaVersion {
 30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 31        write!(f, "{}", <&'static str>::from(self))
 32    }
 33}
 34
 35impl ZetaVersion {
 36    pub fn parse(version_string: &str) -> Result<Self> {
 37        let mut results = ZetaVersion::iter().filter(|version| {
 38            <&'static str>::from(version)
 39                .to_lowercase()
 40                .contains(&version_string.to_lowercase())
 41        });
 42        let Some(result) = results.next() else {
 43            anyhow::bail!(
 44                "`{version_string}` did not match any of:\n{}",
 45                Self::options_as_string()
 46            );
 47        };
 48        if results.next().is_some() {
 49            anyhow::bail!(
 50                "`{version_string}` matched more than one of:\n{}",
 51                Self::options_as_string()
 52            );
 53        }
 54        Ok(result)
 55    }
 56
 57    fn options_as_string() -> String {
 58        ZetaVersion::iter()
 59            .map(|version| format!("- {}\n", <&'static str>::from(version)))
 60            .collect::<Vec<_>>()
 61            .concat()
 62    }
 63
 64    pub fn default_as_string() -> String {
 65        <&'static str>::from(Self::default()).to_string()
 66    }
 67}
 68
/// A single entry in the edit history.
#[derive(Clone, Debug, Serialize, Deserialize)]
// Internally tagged: the variant name is serialized in an `"event"` field
// alongside the variant's own fields.
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer, described by a diff.
    BufferChange {
        /// Path after the change; written on the `+++ b` header line.
        path: Arc<Path>,
        /// Path before the change; written on the `--- a` header line.
        old_path: Arc<Path>,
        /// Diff body, appended after the header lines as-is.
        /// NOTE(review): presumably unified-diff hunks without file headers — confirm.
        diff: String,
        /// Whether the change came from an accepted prediction; if so the
        /// event is preceded by a `// User accepted prediction:` line.
        predicted: bool,
        /// Not used when formatting the prompt.
        /// NOTE(review): presumably flags open-source repositories — confirm.
        in_open_source_repo: bool,
    },
}
 80
 81pub fn write_event(prompt: &mut String, event: &Event) {
 82    fn write_path_as_unix_str(prompt: &mut String, path: &Path) {
 83        for component in path.components() {
 84            prompt.push('/');
 85            write!(prompt, "{}", component.as_os_str().display()).ok();
 86        }
 87    }
 88    match event {
 89        Event::BufferChange {
 90            path,
 91            old_path,
 92            diff,
 93            predicted,
 94            in_open_source_repo: _,
 95        } => {
 96            if *predicted {
 97                prompt.push_str("// User accepted prediction:\n");
 98            }
 99            prompt.push_str("--- a");
100            write_path_as_unix_str(prompt, old_path.as_ref());
101            prompt.push_str("\n+++ b");
102            write_path_as_unix_str(prompt, path.as_ref());
103            prompt.push('\n');
104            prompt.push_str(diff);
105        }
106    }
107}
108
/// A file supplied as extra context, written as its own `<|file_sep|>`
/// section of the prompt.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RelatedFile {
    /// Path written on the section's header line.
    pub path: Arc<Path>,
    /// Row that each excerpt's `row_range.end` is compared against: an excerpt
    /// ending before it is followed by a `...` line.
    /// NOTE(review): presumably the file's last row — confirm whether inclusive.
    pub max_row: u32,
    /// Excerpts, written in order.
    pub excerpts: Vec<RelatedExcerpt>,
}

/// A range of rows taken from a [`RelatedFile`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RelatedExcerpt {
    /// Rows covered by `text`; only `end` is read when formatting the prompt.
    pub row_range: Range<u32>,
    /// Excerpt text, written as-is with a newline appended if missing.
    pub text: Arc<str>,
}
121
122pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> String {
123    let mut prompt = String::new();
124    write_related_files(&mut prompt, &input.related_files);
125    write_edit_history_section(&mut prompt, input);
126
127    match version {
128        ZetaVersion::V0112_MiddleAtEnd => {
129            v0112_middle_at_end::write_cursor_excerpt_section(&mut prompt, input);
130        }
131        ZetaVersion::V0113_Ordered => {
132            v0113_ordered::write_cursor_excerpt_section(&mut prompt, input)
133        }
134    }
135
136    prompt
137}
138
139pub fn write_related_files(prompt: &mut String, related_files: &[RelatedFile]) {
140    for file in related_files {
141        let path_str = file.path.to_string_lossy();
142        write!(prompt, "<|file_sep|>{}\n", path_str).ok();
143        for excerpt in &file.excerpts {
144            prompt.push_str(&excerpt.text);
145            if !prompt.ends_with('\n') {
146                prompt.push('\n');
147            }
148            if excerpt.row_range.end < file.max_row {
149                prompt.push_str("...\n");
150            }
151        }
152    }
153}
154
155fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) {
156    prompt.push_str("<|file_sep|>edit history\n");
157    for event in &input.events {
158        write_event(prompt, event);
159    }
160}
161
162mod v0112_middle_at_end {
163    use super::*;
164
165    pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
166        let path_str = input.cursor_path.to_string_lossy();
167        write!(prompt, "<|file_sep|>{}\n", path_str).ok();
168
169        prompt.push_str("<|fim_prefix|>\n");
170        prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
171
172        prompt.push_str("<|fim_suffix|>\n");
173        prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
174        if !prompt.ends_with('\n') {
175            prompt.push('\n');
176        }
177
178        prompt.push_str("<|fim_middle|>current\n");
179        prompt.push_str(
180            &input.cursor_excerpt
181                [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
182        );
183        prompt.push_str(CURSOR_MARKER);
184        prompt.push_str(
185            &input.cursor_excerpt
186                [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
187        );
188        if !prompt.ends_with('\n') {
189            prompt.push('\n');
190        }
191
192        prompt.push_str("<|fim_middle|>updated\n");
193    }
194}
195
196mod v0113_ordered {
197    use super::*;
198
199    pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
200        let path_str = input.cursor_path.to_string_lossy();
201        write!(prompt, "<|file_sep|>{}\n", path_str).ok();
202
203        prompt.push_str("<|fim_prefix|>\n");
204        prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
205        if !prompt.ends_with('\n') {
206            prompt.push('\n');
207        }
208
209        prompt.push_str("<|fim_middle|>current\n");
210        prompt.push_str(
211            &input.cursor_excerpt
212                [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
213        );
214        prompt.push_str(CURSOR_MARKER);
215        prompt.push_str(
216            &input.cursor_excerpt
217                [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
218        );
219        if !prompt.ends_with('\n') {
220            prompt.push('\n');
221        }
222
223        prompt.push_str("<|fim_suffix|>\n");
224        prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
225        if !prompt.ends_with('\n') {
226            prompt.push('\n');
227        }
228
229        prompt.push_str("<|fim_middle|>updated\n");
230    }
231}