1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::fmt::Write;
4use std::ops::Range;
5use std::path::Path;
6use std::sync::Arc;
7use strum::{EnumIter, IntoEnumIterator as _, IntoStaticStr};
8
/// Marker inserted into the editable region of the cursor excerpt, at the cursor's byte offset.
pub const CURSOR_MARKER: &str = "<|user_cursor|>";
10
/// Everything needed to render a Zeta prompt with [`format_zeta_prompt`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ZetaPromptInput {
    /// Path of the file containing the cursor; written on the cursor section's `<|file_sep|>` line.
    pub cursor_path: Arc<Path>,
    /// Text around the cursor. The offsets below are byte offsets into this string.
    pub cursor_excerpt: Arc<str>,
    /// Byte range of `cursor_excerpt` rendered as the `current` region; the text before it is
    /// the prefix and the text after it is the suffix.
    pub editable_range_in_excerpt: Range<usize>,
    /// Byte offset of the cursor in `cursor_excerpt`. It must lie within
    /// `editable_range_in_excerpt` (and on a char boundary), otherwise prompt rendering panics
    /// when slicing.
    pub cursor_offset_in_excerpt: usize,
    /// Edit-history events, rendered in order in the `edit history` section.
    pub events: Vec<Arc<Event>>,
    /// Related files whose excerpts are written before the edit history.
    pub related_files: Vec<RelatedFile>,
}
20
/// Cursor-excerpt layouts supported by [`format_zeta_prompt`].
///
/// Variant names (via `IntoStaticStr`) are what [`ZetaVersion::parse`] matches against and what
/// `Display` prints. `EnumIter` yields the variants in declaration order.
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, EnumIter, IntoStaticStr)]
// Variant names combine a version tag with a layout descriptor (e.g. `V0113_Ordered`), so the
// underscore is deliberate.
#[allow(non_camel_case_types)]
pub enum ZetaVersion {
    /// Prefix, then suffix, then the editable `current` region (middle written last).
    V0112_MiddleAtEnd,
    /// Document order: prefix, editable `current` region, then suffix. This is the default.
    #[default]
    V0113_Ordered,
}
28
29impl std::fmt::Display for ZetaVersion {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 write!(f, "{}", <&'static str>::from(self))
32 }
33}
34
35impl ZetaVersion {
36 pub fn parse(version_string: &str) -> Result<Self> {
37 let mut results = ZetaVersion::iter().filter(|version| {
38 <&'static str>::from(version)
39 .to_lowercase()
40 .contains(&version_string.to_lowercase())
41 });
42 let Some(result) = results.next() else {
43 anyhow::bail!(
44 "`{version_string}` did not match any of:\n{}",
45 Self::options_as_string()
46 );
47 };
48 if results.next().is_some() {
49 anyhow::bail!(
50 "`{version_string}` matched more than one of:\n{}",
51 Self::options_as_string()
52 );
53 }
54 Ok(result)
55 }
56
57 fn options_as_string() -> String {
58 ZetaVersion::iter()
59 .map(|version| format!("- {}\n", <&'static str>::from(version)))
60 .collect::<Vec<_>>()
61 .concat()
62 }
63
64 pub fn default_as_string() -> String {
65 <&'static str>::from(Self::default()).to_string()
66 }
67}
68
/// An entry in the prompt's edit-history section (rendered by [`write_event`]).
///
/// Serialized with an internal `"event"` tag that holds the variant name.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "event")]
pub enum Event {
    /// A change to a buffer, rendered as a unified-diff style entry.
    BufferChange {
        /// Path after the change; rendered on the `+++ b` header line.
        path: Arc<Path>,
        /// Path before the change; rendered on the `--- a` header line.
        /// NOTE(review): presumably differs from `path` only on renames — confirm.
        old_path: Arc<Path>,
        /// Diff body, appended verbatim after the `---`/`+++` header.
        diff: String,
        /// When true, the entry is preceded by a `// User accepted prediction:` line.
        predicted: bool,
        /// Not read by [`write_event`]. NOTE(review): presumably consumed elsewhere — confirm.
        in_open_source_repo: bool,
    },
}
80
81pub fn write_event(prompt: &mut String, event: &Event) {
82 fn write_path_as_unix_str(prompt: &mut String, path: &Path) {
83 for component in path.components() {
84 prompt.push('/');
85 write!(prompt, "{}", component.as_os_str().display()).ok();
86 }
87 }
88 match event {
89 Event::BufferChange {
90 path,
91 old_path,
92 diff,
93 predicted,
94 in_open_source_repo: _,
95 } => {
96 if *predicted {
97 prompt.push_str("// User accepted prediction:\n");
98 }
99 prompt.push_str("--- a");
100 write_path_as_unix_str(prompt, old_path.as_ref());
101 prompt.push_str("\n+++ b");
102 write_path_as_unix_str(prompt, path.as_ref());
103 prompt.push('\n');
104 prompt.push_str(diff);
105 }
106 }
107}
108
/// A file related to the cursor file, included in the prompt as a list of excerpts.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RelatedFile {
    /// Path written on this file's `<|file_sep|>` line.
    pub path: Arc<Path>,
    /// Compared against each excerpt's `row_range.end`; an excerpt ending below this value is
    /// followed by a `...` line. NOTE(review): presumably the file's last row — confirm.
    pub max_row: u32,
    /// Excerpts, written in order.
    pub excerpts: Vec<RelatedExcerpt>,
}
115
/// An excerpt of a [`RelatedFile`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RelatedExcerpt {
    /// Rows covered by `text`; only `end` is read when rendering.
    /// NOTE(review): presumably 0-based and end-exclusive — confirm.
    pub row_range: Range<u32>,
    /// Excerpt text, written verbatim (a trailing newline is added if missing).
    pub text: Arc<str>,
}
121
122pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> String {
123 let mut prompt = String::new();
124 write_related_files(&mut prompt, &input.related_files);
125 write_edit_history_section(&mut prompt, input);
126
127 match version {
128 ZetaVersion::V0112_MiddleAtEnd => {
129 v0112_middle_at_end::write_cursor_excerpt_section(&mut prompt, input);
130 }
131 ZetaVersion::V0113_Ordered => {
132 v0113_ordered::write_cursor_excerpt_section(&mut prompt, input)
133 }
134 }
135
136 prompt
137}
138
139pub fn write_related_files(prompt: &mut String, related_files: &[RelatedFile]) {
140 for file in related_files {
141 let path_str = file.path.to_string_lossy();
142 write!(prompt, "<|file_sep|>{}\n", path_str).ok();
143 for excerpt in &file.excerpts {
144 prompt.push_str(&excerpt.text);
145 if !prompt.ends_with('\n') {
146 prompt.push('\n');
147 }
148 if excerpt.row_range.end < file.max_row {
149 prompt.push_str("...\n");
150 }
151 }
152 }
153}
154
155fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) {
156 prompt.push_str("<|file_sep|>edit history\n");
157 for event in &input.events {
158 write_event(prompt, event);
159 }
160}
161
162mod v0112_middle_at_end {
163 use super::*;
164
165 pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
166 let path_str = input.cursor_path.to_string_lossy();
167 write!(prompt, "<|file_sep|>{}\n", path_str).ok();
168
169 prompt.push_str("<|fim_prefix|>\n");
170 prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
171
172 prompt.push_str("<|fim_suffix|>\n");
173 prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
174 if !prompt.ends_with('\n') {
175 prompt.push('\n');
176 }
177
178 prompt.push_str("<|fim_middle|>current\n");
179 prompt.push_str(
180 &input.cursor_excerpt
181 [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
182 );
183 prompt.push_str(CURSOR_MARKER);
184 prompt.push_str(
185 &input.cursor_excerpt
186 [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
187 );
188 if !prompt.ends_with('\n') {
189 prompt.push('\n');
190 }
191
192 prompt.push_str("<|fim_middle|>updated\n");
193 }
194}
195
196mod v0113_ordered {
197 use super::*;
198
199 pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
200 let path_str = input.cursor_path.to_string_lossy();
201 write!(prompt, "<|file_sep|>{}\n", path_str).ok();
202
203 prompt.push_str("<|fim_prefix|>\n");
204 prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
205 if !prompt.ends_with('\n') {
206 prompt.push('\n');
207 }
208
209 prompt.push_str("<|fim_middle|>current\n");
210 prompt.push_str(
211 &input.cursor_excerpt
212 [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
213 );
214 prompt.push_str(CURSOR_MARKER);
215 prompt.push_str(
216 &input.cursor_excerpt
217 [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
218 );
219 if !prompt.ends_with('\n') {
220 prompt.push('\n');
221 }
222
223 prompt.push_str("<|fim_suffix|>\n");
224 prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
225 if !prompt.ends_with('\n') {
226 prompt.push('\n');
227 }
228
229 prompt.push_str("<|fim_middle|>updated\n");
230 }
231}