1use crate::{
2 PromptFormat,
3 example::{Example, ExamplePrompt},
4 headless::EpAppState,
5 load_project::run_load_project,
6 progress::{Progress, Step},
7 retrieve_context::run_context_retrieval,
8};
9use anyhow::{Context as _, Result, ensure};
10use edit_prediction::{
11 EditPredictionStore,
12 zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
13};
14use gpui::AsyncApp;
15use std::sync::Arc;
16use zeta_prompt::format_zeta_prompt;
17
18pub async fn run_format_prompt(
19 example: &mut Example,
20 prompt_format: PromptFormat,
21 app_state: Arc<EpAppState>,
22 mut cx: AsyncApp,
23) -> Result<()> {
24 run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
25
26 let _step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
27
28 match prompt_format {
29 PromptFormat::Teacher => {
30 let prompt = TeacherPrompt::format_prompt(example);
31 example.prompt = Some(ExamplePrompt {
32 input: prompt,
33 // TODO
34 expected_output: example
35 .spec
36 .expected_patches
37 .first()
38 .context("no expected patches")?
39 .clone(),
40 format: prompt_format,
41 });
42 }
43 PromptFormat::Zeta2 => {
44 run_load_project(example, app_state, cx.clone()).await?;
45
46 let ep_store = cx.update(|cx| {
47 EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
48 })??;
49
50 let state = example.state.as_ref().context("state must be set")?;
51 let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
52 let project = state.project.clone();
53 let (_, input) = ep_store.update(&mut cx, |ep_store, cx| {
54 let events = ep_store
55 .edit_history_for_project(&project, cx)
56 .into_iter()
57 .map(|e| e.event)
58 .collect();
59 anyhow::Ok(zeta2_prompt_input(
60 &snapshot,
61 example
62 .context
63 .as_ref()
64 .context("context must be set")?
65 .files
66 .clone(),
67 events,
68 example.spec.cursor_path.clone(),
69 example
70 .buffer
71 .as_ref()
72 .context("buffer must be set")?
73 .cursor_offset,
74 ))
75 })??;
76 let prompt = format_zeta_prompt(&input);
77 let expected_output = zeta2_output_for_patch(
78 &input,
79 &example
80 .spec
81 .expected_patches
82 .first()
83 .context("expected patches is empty")?
84 .clone(),
85 )?;
86 example.prompt = Some(ExamplePrompt {
87 input: prompt,
88 expected_output,
89 format: prompt_format,
90 });
91 }
92 };
93 Ok(())
94}
95
96pub struct TeacherPrompt;
97
98impl TeacherPrompt {
99 const PROMPT: &str = include_str!("teacher.prompt.md");
100 pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
101 pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
102 pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
103
104 /// Truncate edit history to this number of last lines
105 const MAX_HISTORY_LINES: usize = 128;
106
107 pub fn format_prompt(example: &Example) -> String {
108 let edit_history = Self::format_edit_history(&example.spec.edit_history);
109 let context = Self::format_context(example);
110 let editable_region = Self::format_editable_region(example);
111
112 let prompt = Self::PROMPT
113 .replace("{{context}}", &context)
114 .replace("{{edit_history}}", &edit_history)
115 .replace("{{editable_region}}", &editable_region);
116
117 prompt
118 }
119
120 pub fn parse(example: &Example, response: &str) -> Result<String> {
121 // Ideally, we should always be able to find cursor position in the retrieved context.
122 // In reality, sometimes we don't find it for these reasons:
123 // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
124 // (can be fixed by getting cursor coordinates at the load_example stage)
125 // 2. Context retriever just didn't include cursor line.
126 //
127 // In that case, fallback to using `cursor_position` as excerpt.
128 let cursor_file = &example
129 .buffer
130 .as_ref()
131 .context("`buffer` should be filled in in the context collection step")?
132 .content;
133
134 // Extract updated (new) editable region from the model response
135 let new_editable_region = extract_last_codeblock(response);
136
137 // Reconstruct old editable region we sent to the model
138 let old_editable_region = Self::format_editable_region(example);
139 let old_editable_region = Self::extract_editable_region(&old_editable_region);
140 ensure!(
141 cursor_file.contains(&old_editable_region),
142 "Something's wrong: editable_region is not found in the cursor file"
143 );
144
145 // Apply editable region to a larger context and compute diff.
146 // This is needed to get a better context lines around the editable region
147 let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
148 let diff = language::unified_diff(&cursor_file, &edited_file);
149
150 let diff = indoc::formatdoc! {"
151 --- a/{path}
152 +++ b/{path}
153 {diff}",
154 path = example.spec.cursor_path.to_string_lossy(),
155 diff = diff,
156 };
157
158 Ok(diff)
159 }
160
161 fn format_edit_history(edit_history: &str) -> String {
162 // Strip comments ("garbage lines") from edit history
163 let lines = edit_history
164 .lines()
165 .filter(|&s| Self::is_udiff_content_line(s))
166 .collect::<Vec<_>>();
167
168 let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
169 &lines[lines.len() - Self::MAX_HISTORY_LINES..]
170 } else {
171 &lines
172 };
173
174 if history_lines.is_empty() {
175 return "(No edit history)".to_string();
176 }
177
178 history_lines.join("\n")
179 }
180
181 fn format_context(example: &Example) -> String {
182 assert!(example.context.is_some(), "Missing context retriever step");
183
184 let mut prompt = String::new();
185 zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
186
187 prompt
188 }
189
190 fn format_editable_region(example: &Example) -> String {
191 let mut result = String::new();
192
193 let path_str = example.spec.cursor_path.to_string_lossy();
194 result.push_str(&format!("`````path=\"{path_str}\"\n"));
195 result.push_str(Self::EDITABLE_REGION_START);
196
197 // TODO: control number of lines around cursor
198 let (mut excerpt, offset) = example.spec.cursor_excerpt().unwrap();
199 excerpt.insert_str(offset, Self::USER_CURSOR_MARKER);
200 result.push_str(&excerpt);
201 if !result.ends_with('\n') {
202 result.push('\n');
203 }
204
205 result.push_str(Self::EDITABLE_REGION_END);
206 result.push_str("\n`````");
207
208 result
209 }
210
211 fn extract_editable_region(text: &str) -> String {
212 let start = text
213 .find(Self::EDITABLE_REGION_START)
214 .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
215 let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
216
217 let region = &text[start..end];
218
219 region.replace("<|user_cursor|>", "")
220 }
221
222 fn is_udiff_content_line(s: &str) -> bool {
223 s.starts_with("-")
224 || s.starts_with("+")
225 || s.starts_with(" ")
226 || s.starts_with("---")
227 || s.starts_with("+++")
228 || s.starts_with("@@")
229 }
230}
231
/// Returns the contents of the last fenced code block in `text`, or all of `text` if it contains
/// no complete fenced block.
///
/// A fence is a run of three or more backticks at the start of a line; the rest of the opening
/// line (the info string, e.g. `path="..."`) is skipped. The block ends at the next line that
/// starts with at least as many backticks. Requiring fences to begin a line (as in CommonMark)
/// keeps inline mentions of fences in the model's reasoning, or code lines such as `/// ````,
/// from being mistaken for block boundaries. The returned text includes the trailing newline of
/// the last content line.
fn extract_last_codeblock(text: &str) -> String {
    let mut last_block = None;
    let mut search_start = 0;

    while let Some(start) = find_at_line_start(text, search_start, "```") {
        let fence_len = text[start..].bytes().take_while(|&b| b == b'`').count();
        let fence = &text[start..start + fence_len];

        // Content begins on the line after the opening fence (and its info string).
        let Some(newline) = text[start..].find('\n') else {
            break;
        };
        let content_start = start + newline + 1;

        let Some(close) = find_at_line_start(text, content_start, fence) else {
            break;
        };
        last_block = Some(text[content_start..close].to_string());
        search_start = close + fence_len;
    }

    last_block.unwrap_or_else(|| text.to_string())
}

/// Returns the byte offset of the first occurrence of `needle` at or after `from` that starts a
/// line (i.e. is at offset 0 or directly follows a `'\n'`).
///
/// `needle` must be non-empty and start with an ASCII character, and `from` must lie on a char
/// boundary.
fn find_at_line_start(text: &str, from: usize, needle: &str) -> Option<usize> {
    let mut pos = from;
    loop {
        let idx = pos + text[pos..].find(needle)?;
        if idx == 0 || text.as_bytes()[idx - 1] == b'\n' {
            return Some(idx);
        }
        // `needle` starts with an ASCII byte, so `idx + 1` is a char boundary.
        pos = idx + 1;
    }
}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// The final fenced block wins, even when an earlier block uses a shorter fence.
    #[test]
    fn test_extract_last_code_block() {
        let response = concat!(
            "Some thinking\n",
            "\n",
            "```\n",
            "first block\n",
            "```\n",
            "\n",
            "`````path='something' lines=1:2\n",
            "last block\n",
            "`````\n",
        );

        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    /// Only the text between the start and end markers is kept, including the blank line just
    /// before the end marker.
    #[test]
    fn test_extract_editable_region() {
        let prompt_text = concat!(
            "some lines\n",
            "are\n",
            "here\n",
            "<|editable_region_start|>\n",
            "one\n",
            "two three\n",
            "\n",
            "<|editable_region_end|>\n",
            "more\n",
            "lines here\n",
        );
        let expected_region = "one\ntwo three\n\n";

        assert_eq!(
            TeacherPrompt::extract_editable_region(prompt_text),
            expected_region
        );
    }
}