1use crate::{
2 PromptFormat,
3 example::{Example, ExamplePrompt},
4 headless::EpAppState,
5 load_project::run_load_project,
6 progress::{Progress, Step},
7 retrieve_context::run_context_retrieval,
8};
9use anyhow::{Context as _, Result, ensure};
10use edit_prediction::{
11 EditPredictionStore,
12 zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
13};
14use gpui::AsyncApp;
15use std::sync::Arc;
16use zeta_prompt::format_zeta_prompt;
17
18pub async fn run_format_prompt(
19 example: &mut Example,
20 prompt_format: PromptFormat,
21 app_state: Arc<EpAppState>,
22 mut cx: AsyncApp,
23) -> Result<()> {
24 run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
25
26 let _step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
27
28 match prompt_format {
29 PromptFormat::Teacher => {
30 let prompt = TeacherPrompt::format_prompt(example);
31 example.prompt = Some(ExamplePrompt {
32 input: prompt,
33 expected_output: example.spec.expected_patch.clone(), // TODO
34 format: prompt_format,
35 });
36 }
37 PromptFormat::Zeta2 => {
38 run_load_project(example, app_state, cx.clone()).await?;
39
40 let ep_store = cx.update(|cx| {
41 EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
42 })??;
43
44 let state = example.state.as_ref().context("state must be set")?;
45 let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
46 let project = state.project.clone();
47 let (_, input) = ep_store.update(&mut cx, |ep_store, cx| {
48 let events = ep_store
49 .edit_history_for_project(&project, cx)
50 .into_iter()
51 .map(|e| e.event)
52 .collect();
53 anyhow::Ok(zeta2_prompt_input(
54 &snapshot,
55 example
56 .context
57 .as_ref()
58 .context("context must be set")?
59 .files
60 .clone(),
61 events,
62 example.spec.cursor_path.clone(),
63 example
64 .buffer
65 .as_ref()
66 .context("buffer must be set")?
67 .cursor_offset,
68 ))
69 })??;
70 let prompt = format_zeta_prompt(&input);
71 let expected_output =
72 zeta2_output_for_patch(&input, &example.spec.expected_patch.clone())?;
73 example.prompt = Some(ExamplePrompt {
74 input: prompt,
75 expected_output,
76 format: prompt_format,
77 });
78 }
79 };
80 Ok(())
81}
82
83pub struct TeacherPrompt;
84
85impl TeacherPrompt {
86 const PROMPT: &str = include_str!("teacher.prompt.md");
87 pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
88 pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
89 pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
90
91 /// Truncate edit history to this number of last lines
92 const MAX_HISTORY_LINES: usize = 128;
93
94 pub fn format_prompt(example: &Example) -> String {
95 let edit_history = Self::format_edit_history(&example.spec.edit_history);
96 let context = Self::format_context(example);
97 let editable_region = Self::format_editable_region(example);
98
99 let prompt = Self::PROMPT
100 .replace("{{context}}", &context)
101 .replace("{{edit_history}}", &edit_history)
102 .replace("{{editable_region}}", &editable_region);
103
104 prompt
105 }
106
107 pub fn parse(example: &Example, response: &str) -> Result<String> {
108 // Ideally, we should always be able to find cursor position in the retrieved context.
109 // In reality, sometimes we don't find it for these reasons:
110 // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
111 // (can be fixed by getting cursor coordinates at the load_example stage)
112 // 2. Context retriever just didn't include cursor line.
113 //
114 // In that case, fallback to using `cursor_position` as excerpt.
115 let cursor_file = &example
116 .buffer
117 .as_ref()
118 .context("`buffer` should be filled in in the context collection step")?
119 .content;
120
121 // Extract updated (new) editable region from the model response
122 let new_editable_region = extract_last_codeblock(response);
123
124 // Reconstruct old editable region we sent to the model
125 let old_editable_region = Self::format_editable_region(example);
126 let old_editable_region = Self::extract_editable_region(&old_editable_region);
127 ensure!(
128 cursor_file.contains(&old_editable_region),
129 "Something's wrong: editable_region is not found in the cursor file"
130 );
131
132 // Apply editable region to a larger context and compute diff.
133 // This is needed to get a better context lines around the editable region
134 let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
135 let diff = language::unified_diff(&cursor_file, &edited_file);
136
137 let diff = indoc::formatdoc! {"
138 --- a/{path}
139 +++ b/{path}
140 {diff}",
141 path = example.spec.cursor_path.to_string_lossy(),
142 diff = diff,
143 };
144
145 Ok(diff)
146 }
147
148 fn format_edit_history(edit_history: &str) -> String {
149 // Strip comments ("garbage lines") from edit history
150 let lines = edit_history
151 .lines()
152 .filter(|&s| Self::is_udiff_content_line(s))
153 .collect::<Vec<_>>();
154
155 let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
156 &lines[lines.len() - Self::MAX_HISTORY_LINES..]
157 } else {
158 &lines
159 };
160
161 if history_lines.is_empty() {
162 return "(No edit history)".to_string();
163 }
164
165 history_lines.join("\n")
166 }
167
168 fn format_context(example: &Example) -> String {
169 assert!(example.context.is_some(), "Missing context retriever step");
170
171 let mut prompt = String::new();
172 zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
173
174 prompt
175 }
176
177 fn format_editable_region(example: &Example) -> String {
178 let mut result = String::new();
179
180 let path_str = example.spec.cursor_path.to_string_lossy();
181 result.push_str(&format!("`````path=\"{path_str}\"\n"));
182 result.push_str(Self::EDITABLE_REGION_START);
183
184 // TODO: control number of lines around cursor
185 let (mut excerpt, offset) = example.spec.cursor_excerpt().unwrap();
186 excerpt.insert_str(offset, Self::USER_CURSOR_MARKER);
187 result.push_str(&excerpt);
188 if !result.ends_with('\n') {
189 result.push('\n');
190 }
191
192 result.push_str(Self::EDITABLE_REGION_END);
193 result.push_str("\n`````");
194
195 result
196 }
197
198 fn extract_editable_region(text: &str) -> String {
199 let start = text
200 .find(Self::EDITABLE_REGION_START)
201 .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
202 let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
203
204 let region = &text[start..end];
205
206 region.replace("<|user_cursor|>", "")
207 }
208
209 fn is_udiff_content_line(s: &str) -> bool {
210 s.starts_with("-")
211 || s.starts_with("+")
212 || s.starts_with(" ")
213 || s.starts_with("---")
214 || s.starts_with("+++")
215 || s.starts_with("@@")
216 }
217}
218
/// Returns the contents of the last closed fenced code block in `text`.
///
/// A block opens with a run of three or more backticks; the rest of the opening line (the info
/// string) is skipped. It closes at the next run of at least as many backticks. The returned
/// contents start after the opening line's newline and end right before the closing fence.
///
/// If `text` contains no closed code block, the whole `text` is returned.
fn extract_last_codeblock(text: &str) -> String {
    let bytes = text.as_bytes();
    let mut last_block: Option<&str> = None;
    let mut search_start = 0;

    while let Some(found) = text[search_start..].find("```") {
        let fence_start = search_start + found;
        let fence_len = bytes[fence_start..]
            .iter()
            .take_while(|&&b| b == b'`')
            .count();

        // Skip the info string: the block contents begin after the opening line
        let after_fence = fence_start + fence_len;
        let line_end = text[after_fence..]
            .find('\n')
            .map_or(text.len(), |pos| after_fence + pos);

        let closing_fence = "`".repeat(fence_len);
        let close_start = match text[line_end..].find(&closing_fence) {
            Some(pos) => line_end + pos,
            // Unclosed block: keep whatever was found so far
            None => break,
        };

        // `line_end` points at the newline that ends the opening line, so `close_start` is
        // strictly greater and the range below is never inverted.
        last_block = Some(&text[line_end + 1..close_start]);

        // A closing fence may be longer than the opening one. Consume the whole run so the
        // leftover backticks are not mistaken for the opening fence of another block.
        let close_len = bytes[close_start..]
            .iter()
            .take_while(|&&b| b == b'`')
            .count();
        search_start = close_start + close_len;
    }

    last_block.map_or_else(|| text.to_string(), str::to_string)
}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// With several fenced blocks in a response, only the final one is returned,
    /// and the info string on its opening fence is not part of the result.
    #[test]
    fn test_extract_last_code_block() {
        let response = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};

        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    /// Only the text between the editable-region markers is returned.
    #[test]
    fn test_extract_editable_region() {
        let prompt = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let expected_region = "one\ntwo three\n\n";

        assert_eq!(
            TeacherPrompt::extract_editable_region(prompt),
            expected_region
        );
    }
}