1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::fmt::Write;
4use std::ops::Range;
5use std::path::Path;
6use std::sync::Arc;
7use strum::{EnumIter, IntoEnumIterator as _, IntoStaticStr};
8
/// Marker spliced into the current editable region at the user's cursor.
///
/// Each prompt layout in this file inserts it at `cursor_offset_in_excerpt`, between
/// the text before and after the cursor.
pub const CURSOR_MARKER: &str = "<|user_cursor|>";
10
/// Everything needed to render a Zeta edit-prediction prompt.
///
/// The `*_in_excerpt` fields are byte offsets into `cursor_excerpt`. The prompt writers
/// slice `cursor_excerpt` with them directly, so they must fall on UTF-8 character
/// boundaries and satisfy
/// `editable_range_in_excerpt.start <= cursor_offset_in_excerpt <= editable_range_in_excerpt.end <= cursor_excerpt.len()`;
/// otherwise prompt formatting panics.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ZetaPromptInput {
    /// Path of the file containing the cursor; written after `<|file_sep|>` via
    /// `to_string_lossy`.
    pub cursor_path: Arc<Path>,
    /// Excerpt of the cursor file that contains the editable region.
    pub cursor_excerpt: Arc<str>,
    /// Byte range within `cursor_excerpt` that the model is asked to rewrite.
    pub editable_range_in_excerpt: Range<usize>,
    /// Byte offset within `cursor_excerpt` at which `CURSOR_MARKER` is inserted.
    pub cursor_offset_in_excerpt: usize,
    /// Events rendered, in this order, in the "edit history" section.
    pub events: Vec<Arc<Event>>,
    /// Context files rendered before the edit history section.
    pub related_files: Vec<RelatedFile>,
}
20
/// Prompt layout revision selected when calling `format_zeta_prompt`.
///
/// The strum-generated static string for each variant is its name (no `strum`
/// attributes rename it). That name is what `Display` prints and what
/// `ZetaVersion::parse` / `ZetaVersion::options_as_string` work with. `EnumIter`
/// yields variants in declaration order.
#[derive(
    Default,
    Clone,
    Copy,
    Debug,
    PartialEq,
    Eq,
    Hash,
    EnumIter,
    IntoStaticStr,
    Serialize,
    Deserialize,
)]
// NOTE(review): every variant is already UpperCamelCase, so this allow looks
// unnecessary — confirm nothing (e.g. a macro) relies on it before removing.
#[allow(non_camel_case_types)]
pub enum ZetaVersion {
    /// Rendered by `v0112_middle_at_end`: prefix, suffix, then the current editable region.
    V0112MiddleAtEnd,
    /// Rendered by `v0113_ordered`: prefix, current editable region, then suffix.
    V0113Ordered,
    /// The default. Rendered with the same `v0113_ordered` layout as `V0113Ordered`.
    /// NOTE(review): presumably differs in how the editable region is chosen upstream —
    /// confirm.
    #[default]
    V0114180EditableRegion,
    /// Rendered by `v0120_git_merge_markers`: editable region wrapped in git-style
    /// merge-conflict markers.
    V0120GitMergeMarkers,
}
42
43impl std::fmt::Display for ZetaVersion {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 write!(f, "{}", <&'static str>::from(self))
46 }
47}
48
49impl ZetaVersion {
50 pub fn parse(version_string: &str) -> Result<Self> {
51 let mut results = ZetaVersion::iter().filter(|version| {
52 <&'static str>::from(version)
53 .to_lowercase()
54 .contains(&version_string.to_lowercase())
55 });
56 let Some(result) = results.next() else {
57 anyhow::bail!(
58 "`{version_string}` did not match any of:\n{}",
59 Self::options_as_string()
60 );
61 };
62 if results.next().is_some() {
63 anyhow::bail!(
64 "`{version_string}` matched more than one of:\n{}",
65 Self::options_as_string()
66 );
67 }
68 Ok(result)
69 }
70
71 pub fn options_as_string() -> String {
72 ZetaVersion::iter()
73 .map(|version| format!("- {}\n", <&'static str>::from(version)))
74 .collect::<Vec<_>>()
75 .concat()
76 }
77}
78
/// An event rendered into the prompt's edit history by `write_event`.
///
/// Serialized internally tagged: the variant name is stored under an `"event"` key.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer, rendered as a git-style diff.
    BufferChange {
        /// Path after the change; rendered on the `+++ b` header line.
        path: Arc<Path>,
        /// Path before the change; rendered on the `--- a` header line.
        old_path: Arc<Path>,
        /// Diff body, appended verbatim after the headers.
        /// NOTE(review): presumably unified-diff hunks — confirm with the producer.
        diff: String,
        /// When true, the rendered diff is preceded by a
        /// `// User accepted prediction:` line.
        predicted: bool,
        /// Ignored by `write_event`.
        /// NOTE(review): presumably consumed elsewhere (e.g. filtering) — confirm.
        in_open_source_repo: bool,
    },
}
90
91pub fn write_event(prompt: &mut String, event: &Event) {
92 fn write_path_as_unix_str(prompt: &mut String, path: &Path) {
93 for component in path.components() {
94 prompt.push('/');
95 write!(prompt, "{}", component.as_os_str().display()).ok();
96 }
97 }
98 match event {
99 Event::BufferChange {
100 path,
101 old_path,
102 diff,
103 predicted,
104 in_open_source_repo: _,
105 } => {
106 if *predicted {
107 prompt.push_str("// User accepted prediction:\n");
108 }
109 prompt.push_str("--- a");
110 write_path_as_unix_str(prompt, old_path.as_ref());
111 prompt.push_str("\n+++ b");
112 write_path_as_unix_str(prompt, path.as_ref());
113 prompt.push('\n');
114 prompt.push_str(diff);
115 }
116 }
117}
118
/// A context file whose excerpts are rendered ahead of the edit history.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RelatedFile {
    /// Path written after `<|file_sep|>` via `to_string_lossy`.
    pub path: Arc<Path>,
    /// Row bound for the file. Any excerpt whose `row_range.end` is below this value
    /// is followed by a `...` line.
    /// NOTE(review): presumably the file's last row — confirm whether it is inclusive.
    pub max_row: u32,
    /// Excerpts, rendered in this order.
    pub excerpts: Vec<RelatedExcerpt>,
}
125
/// A contiguous snippet of a `RelatedFile`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RelatedExcerpt {
    /// Rows covered by `text`. In this file only `end` is read, compared against
    /// `RelatedFile::max_row`.
    /// NOTE(review): presumably zero-based and end-exclusive — confirm.
    pub row_range: Range<u32>,
    /// Snippet contents, appended verbatim, plus a newline if the text lacks one.
    pub text: Arc<str>,
}
131
132pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> String {
133 let mut prompt = String::new();
134 write_related_files(&mut prompt, &input.related_files);
135 write_edit_history_section(&mut prompt, input);
136
137 match version {
138 ZetaVersion::V0112MiddleAtEnd => {
139 v0112_middle_at_end::write_cursor_excerpt_section(&mut prompt, input);
140 }
141 ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => {
142 v0113_ordered::write_cursor_excerpt_section(&mut prompt, input)
143 }
144
145 ZetaVersion::V0120GitMergeMarkers => {
146 v0120_git_merge_markers::write_cursor_excerpt_section(&mut prompt, input)
147 }
148 }
149
150 prompt
151}
152
153pub fn write_related_files(prompt: &mut String, related_files: &[RelatedFile]) {
154 for file in related_files {
155 let path_str = file.path.to_string_lossy();
156 write!(prompt, "<|file_sep|>{}\n", path_str).ok();
157 for excerpt in &file.excerpts {
158 prompt.push_str(&excerpt.text);
159 if !prompt.ends_with('\n') {
160 prompt.push('\n');
161 }
162 if excerpt.row_range.end < file.max_row {
163 prompt.push_str("...\n");
164 }
165 }
166 }
167}
168
169fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) {
170 prompt.push_str("<|file_sep|>edit history\n");
171 for event in &input.events {
172 write_event(prompt, event);
173 }
174}
175
176mod v0112_middle_at_end {
177 use super::*;
178
179 pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
180 let path_str = input.cursor_path.to_string_lossy();
181 write!(prompt, "<|file_sep|>{}\n", path_str).ok();
182
183 prompt.push_str("<|fim_prefix|>\n");
184 prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
185
186 prompt.push_str("<|fim_suffix|>\n");
187 prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
188 if !prompt.ends_with('\n') {
189 prompt.push('\n');
190 }
191
192 prompt.push_str("<|fim_middle|>current\n");
193 prompt.push_str(
194 &input.cursor_excerpt
195 [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
196 );
197 prompt.push_str(CURSOR_MARKER);
198 prompt.push_str(
199 &input.cursor_excerpt
200 [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
201 );
202 if !prompt.ends_with('\n') {
203 prompt.push('\n');
204 }
205
206 prompt.push_str("<|fim_middle|>updated\n");
207 }
208}
209
210mod v0113_ordered {
211 use super::*;
212
213 pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
214 let path_str = input.cursor_path.to_string_lossy();
215 write!(prompt, "<|file_sep|>{}\n", path_str).ok();
216
217 prompt.push_str("<|fim_prefix|>\n");
218 prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
219 if !prompt.ends_with('\n') {
220 prompt.push('\n');
221 }
222
223 prompt.push_str("<|fim_middle|>current\n");
224 prompt.push_str(
225 &input.cursor_excerpt
226 [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
227 );
228 prompt.push_str(CURSOR_MARKER);
229 prompt.push_str(
230 &input.cursor_excerpt
231 [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
232 );
233 if !prompt.ends_with('\n') {
234 prompt.push('\n');
235 }
236
237 prompt.push_str("<|fim_suffix|>\n");
238 prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
239 if !prompt.ends_with('\n') {
240 prompt.push('\n');
241 }
242
243 prompt.push_str("<|fim_middle|>updated\n");
244 }
245}
246
247pub mod v0120_git_merge_markers {
248 //! A prompt that uses git-style merge conflict markers to represent the editable region.
249 //!
250 //! Example prompt:
251 //!
252 //! <|file_sep|>path/to/target_file.py
253 //! <|fim_prefix|>
254 //! code before editable region
255 //! <|fim_suffix|>
256 //! code after editable region
257 //! <|fim_middle|>
258 //! <<<<<<< CURRENT
259 //! code that
260 //! needs to<|user_cursor|>
261 //! be rewritten
262 //! =======
263 //!
264 //! Expected output (should be generated by the model):
265 //!
266 //! updated
267 //! code with
268 //! changes applied
269 //! >>>>>>> UPDATED
270
271 use super::*;
272
273 pub const START_MARKER: &str = "<<<<<<< CURRENT\n";
274 pub const SEPARATOR: &str = "=======\n";
275 pub const END_MARKER: &str = ">>>>>>> UPDATED\n";
276
277 pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
278 let path_str = input.cursor_path.to_string_lossy();
279 write!(prompt, "<|file_sep|>{}\n", path_str).ok();
280
281 prompt.push_str("<|fim_prefix|>");
282 prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
283
284 prompt.push_str("<|fim_suffix|>");
285 prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
286 if !prompt.ends_with('\n') {
287 prompt.push('\n');
288 }
289
290 prompt.push_str("<|fim_middle|>");
291 prompt.push_str(START_MARKER);
292 prompt.push_str(
293 &input.cursor_excerpt
294 [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
295 );
296 prompt.push_str(CURSOR_MARKER);
297 prompt.push_str(
298 &input.cursor_excerpt
299 [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
300 );
301 if !prompt.ends_with('\n') {
302 prompt.push('\n');
303 }
304 prompt.push_str(SEPARATOR);
305 }
306}