1use crate::{
2 PromptFormat,
3 example::{Example, ExamplePrompt},
4 headless::EpAppState,
5 load_project::run_load_project,
6 progress::{Progress, Step},
7 retrieve_context::run_context_retrieval,
8};
9use anyhow::{Context as _, Result, ensure};
10use edit_prediction::{
11 EditPredictionStore,
12 zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
13};
14use gpui::AsyncApp;
15use std::sync::Arc;
16use zeta_prompt::format_zeta_prompt;
17
18pub async fn run_format_prompt(
19 example: &mut Example,
20 prompt_format: PromptFormat,
21 app_state: Arc<EpAppState>,
22 mut cx: AsyncApp,
23) -> Result<()> {
24 run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
25
26 let _step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
27
28 match prompt_format {
29 PromptFormat::Teacher => {
30 let prompt = TeacherPrompt::format_prompt(example);
31 example.prompt = Some(ExamplePrompt {
32 input: prompt,
33 expected_output: example
34 .spec
35 .expected_patches
36 .first()
37 .cloned()
38 .unwrap_or_default(),
39 format: prompt_format,
40 });
41 }
42 PromptFormat::Zeta2 => {
43 run_load_project(example, app_state, cx.clone()).await?;
44
45 let ep_store = cx.update(|cx| {
46 EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
47 })??;
48
49 let state = example.state.as_ref().context("state must be set")?;
50 let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
51 let project = state.project.clone();
52 let (_, input) = ep_store.update(&mut cx, |ep_store, cx| {
53 let events = ep_store
54 .edit_history_for_project(&project, cx)
55 .into_iter()
56 .map(|e| e.event)
57 .collect();
58 anyhow::Ok(zeta2_prompt_input(
59 &snapshot,
60 example
61 .context
62 .as_ref()
63 .context("context must be set")?
64 .files
65 .clone(),
66 events,
67 example.spec.cursor_path.clone(),
68 example
69 .buffer
70 .as_ref()
71 .context("buffer must be set")?
72 .cursor_offset,
73 ))
74 })??;
75 let prompt = format_zeta_prompt(&input);
76 let expected_output = zeta2_output_for_patch(
77 &input,
78 &example
79 .spec
80 .expected_patches
81 .first()
82 .context("expected patches is empty")?
83 .clone(),
84 )?;
85 example.prompt = Some(ExamplePrompt {
86 input: prompt,
87 expected_output,
88 format: prompt_format,
89 });
90 }
91 };
92 Ok(())
93}
94
95pub struct TeacherPrompt;
96
97impl TeacherPrompt {
98 const PROMPT: &str = include_str!("teacher.prompt.md");
99 pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
100 pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
101 pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
102
103 /// Truncate edit history to this number of last lines
104 const MAX_HISTORY_LINES: usize = 128;
105
106 pub fn format_prompt(example: &Example) -> String {
107 let edit_history = Self::format_edit_history(&example.spec.edit_history);
108 let context = Self::format_context(example);
109 let editable_region = Self::format_editable_region(example);
110
111 let prompt = Self::PROMPT
112 .replace("{{context}}", &context)
113 .replace("{{edit_history}}", &edit_history)
114 .replace("{{editable_region}}", &editable_region);
115
116 prompt
117 }
118
119 pub fn parse(example: &Example, response: &str) -> Result<String> {
120 // Ideally, we should always be able to find cursor position in the retrieved context.
121 // In reality, sometimes we don't find it for these reasons:
122 // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
123 // (can be fixed by getting cursor coordinates at the load_example stage)
124 // 2. Context retriever just didn't include cursor line.
125 //
126 // In that case, fallback to using `cursor_position` as excerpt.
127 let cursor_file = &example
128 .buffer
129 .as_ref()
130 .context("`buffer` should be filled in in the context collection step")?
131 .content;
132
133 // Extract updated (new) editable region from the model response
134 let new_editable_region = extract_last_codeblock(response);
135
136 // Reconstruct old editable region we sent to the model
137 let old_editable_region = Self::format_editable_region(example);
138 let old_editable_region = Self::extract_editable_region(&old_editable_region);
139 ensure!(
140 cursor_file.contains(&old_editable_region),
141 "Something's wrong: editable_region is not found in the cursor file"
142 );
143
144 // Apply editable region to a larger context and compute diff.
145 // This is needed to get a better context lines around the editable region
146 let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
147 let diff = language::unified_diff(&cursor_file, &edited_file);
148
149 let diff = indoc::formatdoc! {"
150 --- a/{path}
151 +++ b/{path}
152 {diff}",
153 path = example.spec.cursor_path.to_string_lossy(),
154 diff = diff,
155 };
156
157 Ok(diff)
158 }
159
160 fn format_edit_history(edit_history: &str) -> String {
161 // Strip comments ("garbage lines") from edit history
162 let lines = edit_history
163 .lines()
164 .filter(|&s| Self::is_udiff_content_line(s))
165 .collect::<Vec<_>>();
166
167 let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
168 &lines[lines.len() - Self::MAX_HISTORY_LINES..]
169 } else {
170 &lines
171 };
172
173 if history_lines.is_empty() {
174 return "(No edit history)".to_string();
175 }
176
177 history_lines.join("\n")
178 }
179
180 fn format_context(example: &Example) -> String {
181 assert!(example.context.is_some(), "Missing context retriever step");
182
183 let mut prompt = String::new();
184 zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
185
186 prompt
187 }
188
189 fn format_editable_region(example: &Example) -> String {
190 let mut result = String::new();
191
192 let path_str = example.spec.cursor_path.to_string_lossy();
193 result.push_str(&format!("`````path=\"{path_str}\"\n"));
194 result.push_str(Self::EDITABLE_REGION_START);
195
196 // TODO: control number of lines around cursor
197 let (mut excerpt, offset) = example.spec.cursor_excerpt().unwrap();
198 excerpt.insert_str(offset, Self::USER_CURSOR_MARKER);
199 result.push_str(&excerpt);
200 if !result.ends_with('\n') {
201 result.push('\n');
202 }
203
204 result.push_str(Self::EDITABLE_REGION_END);
205 result.push_str("\n`````");
206
207 result
208 }
209
210 fn extract_editable_region(text: &str) -> String {
211 let start = text
212 .find(Self::EDITABLE_REGION_START)
213 .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
214 let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
215
216 let region = &text[start..end];
217
218 region.replace("<|user_cursor|>", "")
219 }
220
221 fn is_udiff_content_line(s: &str) -> bool {
222 s.starts_with("-")
223 || s.starts_with("+")
224 || s.starts_with(" ")
225 || s.starts_with("---")
226 || s.starts_with("+++")
227 || s.starts_with("@@")
228 }
229}
230
/// Returns the body of the last complete fenced code block in `text`, or all
/// of `text` when no complete block is found.
///
/// A fence is a run of three or more backticks. The block is closed by the
/// next run of at least as many backticks. Anything after the opening fence on
/// its line (language tag, `path=...` attributes) is skipped. The body starts
/// after the opening line's newline and ends right before the closing fence,
/// so it normally includes the trailing newline of its last line.
fn extract_last_codeblock(text: &str) -> String {
    let bytes = text.as_bytes();
    let mut cursor = 0;
    let mut found: Option<&str> = None;

    while let Some(offset) = text[cursor..].find("```") {
        let fence_start = cursor + offset;
        let fence_len = bytes[fence_start..]
            .iter()
            .take_while(|&&b| b == b'`')
            .count();
        let fence = "`".repeat(fence_len);

        // Skip the info string: jump to the newline ending the opening line (or EOF).
        let after_fence = fence_start + fence_len;
        let line_end = bytes[after_fence..]
            .iter()
            .position(|&b| b == b'\n')
            .map_or(bytes.len(), |pos| after_fence + pos);

        let Some(close_offset) = text[line_end..].find(&fence) else {
            break;
        };
        found = Some(&text[line_end + 1..line_end + close_offset]);
        cursor = line_end + close_offset + fence_len;
    }

    found.map_or_else(|| text.to_string(), str::to_string)
}
262
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_last_code_block() {
        let text = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(last_block, "last block\n");
    }

    #[test]
    fn test_extract_last_code_block_without_fences_returns_input() {
        // With no fenced block at all, the whole response is treated as the block.
        let text = "no code block here";
        assert_eq!(extract_last_codeblock(text), text);
    }

    #[test]
    fn test_extract_last_code_block_ignores_unterminated_trailing_block() {
        let text = "```\nfirst\n```\n```\nunterminated";
        assert_eq!(extract_last_codeblock(text), "first\n");
    }

    #[test]
    fn test_extract_last_code_block_keeps_nested_shorter_fence() {
        // A five-backtick fence is not closed by an inner three-backtick fence.
        let text = "`````path='a.md'\n```\ninner\n```\n`````\n";
        assert_eq!(extract_last_codeblock(text), "```\ninner\n```\n");
    }

    #[test]
    fn test_extract_editable_region() {
        let text = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
            one
            two three

            "}
        );
    }

    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let text = "<|editable_region_start|>\nfoo<|user_cursor|>bar\n<|editable_region_end|>";
        assert_eq!(TeacherPrompt::extract_editable_region(text), "foobar\n");
    }
}